aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-15 17:20:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-15 18:31:54 -0700
commite29d94947a7ef28b4917b22c019aacbe7dacc6db (patch)
tree81e25bd22ca82f8a004bc169f8ec33a99325cb99 /tensorflow/python/framework/ops.py
parent93b8a053c01c1b388394f27cb63bbf460b82f062 (diff)
Add an experimental tf.Graph._attr_scope() that can be used to tag nodes with additional attributes.
Change: 133334381
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index caa9695c8a..3a066931ca 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2036,6 +2036,8 @@ class Graph(object):
self._collections = {}
# The graph-level random seed
self._seed = None
+ # A dictionary of attributes that should be applied to all ops.
+ self._attr_scope_map = {}
# A map from op type to the kernel label that should be used.
self._op_to_kernel_label_map = {}
# A map from op type to an alternative op type that should be used when
@@ -2349,6 +2351,12 @@ class Graph(object):
node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
+ # Apply any additional attributes requested. Do not overwrite any existing
+ # attributes.
+ for key, value in self._attr_scope_map.items():
+ if key not in node_def.attr:
+ node_def.attr[key].CopyFrom(value)
+
# Apply a kernel label if one has been specified for this op_type.
try:
kernel_label = self._op_to_kernel_label_map[op_type]
@@ -3347,6 +3355,69 @@ class Graph(object):
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
+ def _attr_scope(self, attr_map):
+ """EXPERIMENTAL: A context manager for setting attributes on operators.
+
+ This context manager can be used to add additional
+ attributes to operators within the scope of the context.
+
+ For example:
+
+ with ops.Graph().as_default() as g:
+ f_1 = Foo() # No extra attributes
+ with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
+ f_2 = Foo() # Additional attribute _a=False
+ with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
+ f_3 = Foo() # Additional attribute _a=False
+ with g._attr_scope({"_a": None}):
+ f_4 = Foo() # No additional attributes.
+
+ Args:
+ attr_map: A dictionary mapping attr name strings to
+ AttrValue protocol buffers or None.
+
+ Returns:
+ A context manager that sets the kernel label to be used for one or more
+ ops created in that context.
+
+ Raises:
+ TypeError: If attr_map is not a dictionary mapping
+ strings to AttrValue protobufs.
+ """
+ if not isinstance(attr_map, dict):
+ raise TypeError("attr_map must be a dictionary mapping "
+ "strings to AttrValue protocol buffers")
+ # The saved_attrs dictionary stores any currently-set labels that
+ # will be overridden by this context manager.
+ saved_attrs = {}
+ # Install the given attribute
+ for name, attr in attr_map.items():
+ if not (isinstance(name, six.string_types)
+ and isinstance(attr, (type(None), attr_value_pb2.AttrValue))):
+ raise TypeError("attr_map must be a dictionary mapping "
+ "strings to AttrValue protocol buffers")
+ try:
+ saved_attrs[name] = self._attr_scope_map[name]
+ except KeyError:
+ pass
+ if attr is None:
+ del self._attr_scope_map[name]
+ else:
+ self._attr_scope_map[name] = attr
+ try:
+ yield # The code within the context runs here.
+ finally:
+ # Remove the attributes set for this context, and restore any saved
+ # attributes.
+ for name, attr in attr_map.items():
+ try:
+ self._attr_scope_map[name] = saved_attrs[name]
+ except KeyError:
+ del self._attr_scope_map[name]
+ # pylint: enable=g-doc-return-or-yield
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
def _kernel_label_map(self, op_to_kernel_label_map):
"""EXPERIMENTAL: A context manager for setting kernel labels.