aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework')
-rwxr-xr-xtensorflow/python/framework/__init__.py0
-rw-r--r--tensorflow/python/framework/device.py220
-rw-r--r--tensorflow/python/framework/device_test.py122
-rw-r--r--tensorflow/python/framework/docs.py492
-rw-r--r--tensorflow/python/framework/errors.py410
-rw-r--r--tensorflow/python/framework/errors_test.py63
-rw-r--r--tensorflow/python/framework/framework_lib.py70
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py114
-rwxr-xr-xtensorflow/python/framework/gen_docs_test.sh4
-rw-r--r--tensorflow/python/framework/importer.py303
-rw-r--r--tensorflow/python/framework/importer_test.py546
-rw-r--r--tensorflow/python/framework/op_def_registry.py23
-rw-r--r--tensorflow/python/framework/ops.py2985
-rw-r--r--tensorflow/python/framework/ops_test.py825
-rw-r--r--tensorflow/python/framework/python_op_gen.cc678
-rw-r--r--tensorflow/python/framework/python_op_gen.h17
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc30
-rw-r--r--tensorflow/python/framework/random_seed.py136
-rw-r--r--tensorflow/python/framework/registry.py64
-rw-r--r--tensorflow/python/framework/registry_test.py38
-rw-r--r--tensorflow/python/framework/tensor_shape.py743
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py232
-rw-r--r--tensorflow/python/framework/tensor_util.py511
-rw-r--r--tensorflow/python/framework/tensor_util_test.py379
-rw-r--r--tensorflow/python/framework/test_kernel_label_op.cc47
-rw-r--r--tensorflow/python/framework/test_util.py437
-rw-r--r--tensorflow/python/framework/test_util_test.py128
-rw-r--r--tensorflow/python/framework/types.py418
-rw-r--r--tensorflow/python/framework/types_test.py174
29 files changed, 10209 insertions, 0 deletions
diff --git a/tensorflow/python/framework/__init__.py b/tensorflow/python/framework/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/framework/__init__.py
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
new file mode 100644
index 0000000000..676e5f779a
--- /dev/null
+++ b/tensorflow/python/framework/device.py
@@ -0,0 +1,220 @@
+"""Class to represent a device."""
+import copy
+
+
class Device(object):
  """Represents a TensorFlow device specification.

  A device is described by up to five fields -- job name, replica index,
  task index, device type (e.g. "CPU" or "GPU") and device index.  Any
  field may be None, meaning "unspecified".
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new device object.

    Args:
      job: string. Optional device job name.
      replica: int. Optional replica index.
      task: int. Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int. Optional device index. If left
        unspecified, device represents 'any' device_index.
    """
    self.job = job
    self.replica = replica
    self.task = task
    # Lowercase "cpu"/"gpu" are accepted for backwards compatibility only
    # and normalized to uppercase here.
    if device_type in ("cpu", "gpu"):
      self.device_type = device_type.upper()
    else:
      self.device_type = device_type
    self.device_index = device_index

  def _clear(self):
    # Reset every field to the "unspecified" state before re-parsing.
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    """Job name, or None if unspecified."""
    return self._job

  @job.setter
  def job(self, job):
    self._job = str(job) if job is not None else None

  @property
  def replica(self):
    """Replica index, or None if unspecified."""
    return self._replica

  @replica.setter
  def replica(self, replica):
    self._replica = int(replica) if replica is not None else None

  @property
  def task(self):
    """Task index, or None if unspecified."""
    return self._task

  @task.setter
  def task(self, task):
    self._task = int(task) if task is not None else None

  def parse_from_string(self, spec):
    """Parse a Device name into its components.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The Device, for convenience.

    Raises:
      ValueError: if the spec was not valid.
    """
    self._clear()
    for part in spec.split("/"):
      fields = part.split(":")
      nfields = len(fields)
      head = fields[0]
      if nfields == 2 and head in ("job", "replica", "task"):
        # Route through the property setters so int()/str() coercion and
        # validation still apply.
        setattr(self, head, fields[1])
      elif nfields in (1, 2) and head.upper() in ("CPU", "GPU"):
        # Legacy short forms "/cpu:0" and "/gpu:1" (case-insensitive).
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = head.upper()
        if nfields == 2 and fields[1] != "*":
          self.device_index = int(fields[1])
      elif nfields == 3 and head == "device":
        # Canonical form "/device:<TYPE>:<index or *>".
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = fields[1]
        if fields[2] != "*":
          self.device_index = int(fields[2])
      elif head:  # Empty components (leading "/") are silently skipped.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (head, spec))
    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this Device.

    Fields that are set (non-None) on `dev` overwrite the corresponding
    fields here; unset fields on `dev` leave this Device untouched.

    Args:
      dev: a Device.
    """
    for field in ("job", "replica", "task", "device_type", "device_index"):
      value = getattr(dev, field)
      if value is not None:
        setattr(self, field, value)

  def to_string(self):
    """Return a Device specification string.

    Returns:
      a string of the form /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or /job:<name>/replica:<id>/task:<id>/device:GPU:<id>.  Unset fields
      are omitted; an unset device_index prints as "*".
    """
    parts = []
    if self.job is not None:
      parts.append("/job:" + self.job)
    if self.replica is not None:
      parts.append("/replica:" + str(self.replica))
    if self.task is not None:
      parts.append("/task:" + str(self.task))
    if self.device_type is not None:
      index = "*" if self.device_index is None else str(self.device_index)
      parts.append("/device:%s:%s" % (self.device_type, index))
    return "".join(parts)
+
+
def from_string(spec):
  """Construct a Device from a string.

  Args:
    spec: a string of the form
     /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
    or
     /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
    as cpu and gpu are mutually exclusive.
    All entries are optional.

  Returns:
    A Device.
  """
  result = Device()
  return result.parse_from_string(spec)
+
+
def check_valid(spec):
  """Check that a device spec is valid.

  Args:
    spec: a string.

  Raises:
    ValueError: if the spec is invalid.
  """
  # Parsing performs the full validation; the constructed Device is
  # discarded.
  from_string(spec)
+
+
def merge_device(spec):
  """Returns a device function that merges devices specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A device or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block.

  Returns:
    A device function with the above-described behavior.

  Raises:
    ValueError: if the spec was not valid.
  """
  if not isinstance(spec, Device):
    # Parse eagerly so an invalid spec fails here, not on first use.
    spec = from_string(spec or "")
  def _device_function(node_def):
    # Fields already set on the node take precedence over `spec`.
    merged = copy.copy(spec)
    merged.merge_from(from_string(node_def.device or ""))
    return merged
  return _device_function
diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py
new file mode 100644
index 0000000000..0a244b0815
--- /dev/null
+++ b/tensorflow/python/framework/device_test.py
@@ -0,0 +1,122 @@
+"""Tests for tensorflow.python.framework.device."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import device
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
class DeviceTest(test_util.TensorFlowTestCase):
  """Unit tests for device.Device parsing, printing and merging."""

  def testEmpty(self):
    d = device.Device()
    # Fixed: the Device class defines to_string(), not ToString().
    self.assertEquals("", d.to_string())
    d.parse_from_string("")
    self.assertEquals("", d.to_string())

  def testConstructor(self):
    d = device.Device(job="j", replica=0, task=1,
                      device_type="CPU", device_index=2)
    self.assertEquals("j", d.job)
    self.assertEquals(0, d.replica)
    self.assertEquals(1, d.task)
    self.assertEquals("CPU", d.device_type)
    self.assertEquals(2, d.device_index)
    self.assertEquals("/job:j/replica:0/task:1/device:CPU:2", d.to_string())

    d = device.Device(device_type="GPU", device_index=0)
    self.assertEquals("/device:GPU:0", d.to_string())

  def testto_string(self):
    d = device.Device()
    d.job = "foo"
    self.assertEquals("/job:foo", d.to_string())
    d.task = 3
    self.assertEquals("/job:foo/task:3", d.to_string())
    d.device_type = "CPU"
    d.device_index = 0
    self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string())
    d.task = None
    d.replica = 12
    self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string())
    d.device_type = "GPU"
    d.device_index = 2
    self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string())
    d.device_type = "CPU"
    d.device_index = 1
    self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string())
    d.device_type = None
    d.device_index = None
    # (A stray "d.cpu = None" line was removed here; Device has no such
    # field, so the assignment only created an unused attribute.)
    self.assertEquals("/job:foo/replica:12", d.to_string())

    # Test wildcard
    d = device.Device(job="foo", replica=12, task=3, device_type="GPU")
    self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())

  def testParse(self):
    d = device.Device()
    d.parse_from_string("/job:foo/replica:0")
    self.assertEquals("/job:foo/replica:0", d.to_string())
    d.parse_from_string("/replica:1/task:0/cpu:0")
    self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
    d.parse_from_string("/replica:1/task:0/device:CPU:0")
    self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string())
    d.parse_from_string("/job:muu/gpu:2")
    self.assertEquals("/job:muu/device:GPU:2", d.to_string())
    with self.assertRaises(Exception) as e:
      d.parse_from_string("/job:muu/gpu:2/cpu:0")
    # Fixed: this assertion must run *after* the assertRaises block; when it
    # was indented inside, it was dead code (the call above raises first).
    self.assertTrue("Cannot specify multiple device" in str(e.exception))

  def testFromString(self):
    d = device.from_string("/job:foo/replica:0")
    self.assertEquals("/job:foo/replica:0", d.to_string())
    with self.assertRaises(Exception) as e:
      d = device.from_string("/job:muu/gpu:2/cpu:0")
    self.assertTrue("Cannot specify multiple device" in str(e.exception))

    d = device.from_string("/job:foo/replica:0/task:3/cpu:*")
    self.assertEquals(None, d.device_index)
    d = device.from_string("/job:foo/replica:0/task:3/gpu:7")
    self.assertEquals(7, d.device_index)
    d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7")
    self.assertEquals(7, d.device_index)

  def testMerge(self):
    d = device.from_string("/job:foo/replica:0")
    self.assertEquals("/job:foo/replica:0", d.to_string())
    d.merge_from(device.from_string("/task:1/gpu:2"))
    self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    d = device.Device()
    d.merge_from(device.from_string("/task:1/cpu:0"))
    self.assertEquals("/task:1/device:CPU:0", d.to_string())
    d.merge_from(device.from_string("/job:boo/gpu:0"))
    self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string())
    d.merge_from(device.from_string("/job:muu/cpu:2"))
    self.assertEquals("/job:muu/task:1/device:CPU:2", d.to_string())
    d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
    self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())

  def testCheckValid(self):
    # Fixed: the module-level function is check_valid(), not CheckValid().
    device.check_valid("/job:foo/replica:0")

    with self.assertRaises(Exception) as e:
      device.check_valid("/job:j/replica:foo")
    self.assertTrue("invalid literal for int" in str(e.exception))

    with self.assertRaises(Exception) as e:
      device.check_valid("/job:j/task:bar")
    self.assertTrue("invalid literal for int" in str(e.exception))

    with self.assertRaises(Exception) as e:
      device.check_valid("/bar:muu/baz:2")
    self.assertTrue("Unknown attribute: 'bar'" in str(e.exception))

    with self.assertRaises(Exception) as e:
      device.check_valid("/cpu:0/gpu:2")
    self.assertTrue("Cannot specify multiple device" in str(e.exception))
+
+
if __name__ == "__main__":
  # Runs the DeviceTest cases above via the project's googletest wrapper.
  googletest.main()
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
new file mode 100644
index 0000000000..68dbb3df72
--- /dev/null
+++ b/tensorflow/python/framework/docs.py
@@ -0,0 +1,492 @@
+"""Updates generated docs from Python doc comments.
+
+Both updates the files in the file-system and executes g4 commands to
+make sure any changes are ready to be submitted.
+"""
+
+import inspect
+import os
+import re
+import sys
+
+
# Matches an "argname:" item inside an Args:-style docstring section,
# including the *args / **kwargs starred forms.
_arg_re = re.compile(" *([*]{0,2}[a-zA-Z][a-zA-Z0-9_]*):")
# Matches a docstring section header line such as "Args:" or "Returns:".
_section_re = re.compile("([A-Z][a-zA-Z ]*):$")
# Matches (via re.match) names beginning with an underscore followed by a
# word character; such symbols are always dropped from the docs.
_always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]")
# A valid HTML anchor here consists solely of word characters and dots.
_anchor_re = re.compile(r"^[\w.]+$")
# Prefix marking a "@@name" member reference inside a module docstring.
_member_mark = "@@"
+
+
class Document(object):
  """Base class for an automatically generated document."""

  def write_markdown_to_file(self, f):
    """Writes a Markdown-formatted version of this document to file `f`.

    Subclasses (e.g. `Index`, `Library`) must override this method.

    Args:
      f: The output file.

    Raises:
      NotImplementedError: always, on the base class.
    """
    # Fixed: the message previously named a nonexistent method
    # ("Document.WriteToFile"); it now names the actual method.
    raise NotImplementedError("Document.write_markdown_to_file")
+
+
class Index(Document):
  """An automatically generated index for a collection of documents."""

  def __init__(self, module_to_name, members, filename_to_library_map):
    """Creates a new Index.

    Args:
      module_to_name: Dictionary mapping modules to short names.
      members: Dictionary mapping member name to (fullname, member).
      filename_to_library_map: A list of (filename, Library) pairs. The order
        corresponds to the order in which the libraries appear in the index.
    """
    self._module_to_name = module_to_name
    self._members = members
    self._filename_to_library_map = filename_to_library_map

  def write_markdown_to_file(self, f):
    """Writes this index to file `f`.

    The output is formatted as an unordered list. Each list element
    contains the title of the library, followed by a list of symbols
    in that library hyperlinked to the corresponding anchor in that
    library.

    Args:
      f: The output file.
    """
    print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
    print >>f, ""
    print >>f, "# TensorFlow Python reference documentation"
    print >>f, ""
    for filename, library in self._filename_to_library_map:
      per_symbol_links = []
      for name in sorted(library.mentioned):
        # Only mentioned symbols that are also known members get a link.
        if name in self._members:
          fullname, member = self._members[name]
          anchor = _get_anchor(self._module_to_name, fullname)
          # bool * str repeats the string 0 or 1 times, so classes get a
          # "class " prefix and everything else gets none.
          prefix = "class " * inspect.isclass(member)
          per_symbol_links.append("[%s%s](%s#%s)" %
                                  (prefix, name, filename, anchor))
      if per_symbol_links:
        print >>f, "* <b>[%s](%s)</b>: %s" % (library.title, filename,
                                              ",\n  ".join(per_symbol_links))
    print >>f, ""

    # actually include the files right here
    print >>f, '<div class="sections-order" style="display: none;">\n<!--'
    for filename, _ in self._filename_to_library_map:
      print >>f, "<!-- %s -->" % filename
    print >>f, "-->\n</div>"
+
def collect_members(module_to_name):
  """Collect all symbols from a list of modules.

  Only functions and classes are collected; names matching
  `_always_drop_symbol_re` (leading underscore) are skipped.  When the same
  member is reachable under several module names, the shortest full name
  wins.

  Args:
    module_to_name: Dictionary mapping modules to short names.

  Returns:
    Dictionary mapping name to (fullname, member) pairs.

  Raises:
    RuntimeError: If two *distinct* members share a short name, or if two
      candidate full names for the same member are equally long (so no
      shortest one can be chosen).
  """
  members = {}
  # .items() instead of .iteritems() keeps this working on both Python 2
  # and Python 3.
  for module, module_name in module_to_name.items():
    for name, member in inspect.getmembers(module):
      if ((inspect.isfunction(member) or inspect.isclass(member)) and
          not _always_drop_symbol_re.match(name)):
        fullname = '%s.%s' % (module_name, name)
        if name in members:
          other_fullname, other_member = members[name]
          if member is not other_member:
            raise RuntimeError("Short name collision between %s and %s" %
                               (fullname, other_fullname))
          if len(fullname) == len(other_fullname):
            # Fixed: the format string has four placeholders but only three
            # arguments were supplied, so raising this error itself raised a
            # TypeError. `name` is now passed for the "for %s" slot.
            raise RuntimeError("Can't decide whether to use %s or %s for %s: "
                               "both full names have length %d" %
                               (fullname, other_fullname, name, len(fullname)))
          if len(fullname) > len(other_fullname):
            continue  # Use the shorter full name
        members[name] = fullname, member
  return members
+
+
def _get_anchor(module_to_name, fullname):
  """Turn a full member name into an anchor.

  Args:
    module_to_name: Dictionary mapping modules to short names.
    fullname: Fully qualified name of symbol.

  Returns:
    HTML anchor string.  The longest module name prefix of fullname is
    removed to make the anchor.

  Raises:
    ValueError: If fullname uses characters invalid in an anchor.
  """
  if not _anchor_re.match(fullname):
    raise ValueError("'%s' is not a valid anchor" % fullname)
  anchor = fullname
  # .values() instead of .itervalues() keeps this working on both Python 2
  # and Python 3 (consistent with collect_members).
  for module_name in module_to_name.values():
    if fullname.startswith(module_name + "."):
      rest = fullname[len(module_name)+1:]
      # Use this prefix iff it is longer than any found before, i.e. keep
      # the shortest remaining anchor text.
      if len(anchor) > len(rest):
        anchor = rest
  return anchor
+
+
class Library(Document):
  """An automatically generated document for a set of functions and classes."""

  def __init__(self,
               title,
               module,
               module_to_name,
               members,
               documented,
               exclude_symbols=(),
               catch_all=False):
    """Creates a new Library.

    Args:
      title: A human-readable title for the library.
      module: Module to pull high level docstring from (for table of contents,
        list of Ops to document, etc.).
      module_to_name: Dictionary mapping modules to short names.
      members: Dictionary mapping member name to (fullname, member).
      documented: Set of documented names to update.
      exclude_symbols: A list of specific symbols to exclude.
      catch_all: Not referenced in this constructor; see the parameter of
        the same name on write_other_members.
    """
    self._title = title
    self._module = module
    self._module_to_name = module_to_name
    self._members = dict(members)  # Copy since we mutate it below
    self._exclude_symbols = frozenset(exclude_symbols)
    # Excluded symbols count as "documented" so they are never reported as
    # leftovers.
    documented.update(exclude_symbols)
    self._documented = documented
    self._mentioned = set()

  @property
  def title(self):
    """The human-readable title for this library."""
    return self._title

  @property
  def mentioned(self):
    """Set of names mentioned in this library."""
    return self._mentioned

  @property
  def exclude_symbols(self):
    """Set of excluded symbols."""
    return self._exclude_symbols

  def _should_include_member(self, name, member):
    """Returns True if this member should be included in the document."""
    # Always exclude symbols matching _always_drop_symbol_re.
    if _always_drop_symbol_re.match(name):
      return False
    # Finally, exclude any specifically-excluded symbols.
    if name in self._exclude_symbols:
      return False
    return True

  def get_imported_modules(self, module):
    """Returns the list of modules imported from `module`."""
    for name, member in inspect.getmembers(module):
      if inspect.ismodule(member):
        yield name, member

  def get_class_members(self, cls_name, cls):
    """Returns the list of class members to document in `cls`.

    This function filters the class member to ONLY return those
    defined by the class.  It drops the inherited ones.

    Args:
      cls_name: Qualified name of `cls`.
      cls: An inspect object of type 'class'.

    Yields:
      name, member tuples.
    """
    for name, member in inspect.getmembers(cls):
      # Only show methods and properties presently.
      if not (inspect.ismethod(member) or isinstance(member, property)):
        continue
      # __init__ is always documented even though its name would normally
      # be dropped by _should_include_member.
      if ((inspect.ismethod(member) and member.__name__ == "__init__")
          or self._should_include_member(name, member)):
        yield name, ("%s.%s" % (cls_name, name), member)

  def _generate_signature_for_function(self, func):
    """Given a function, returns a string representing its args."""
    args_list = []
    argspec = inspect.getargspec(func)
    # Positional args come first; args at index >= first_arg_with_default
    # line up with argspec.defaults.
    first_arg_with_default = (
        len(argspec.args or []) - len(argspec.defaults or []))
    for arg in argspec.args[:first_arg_with_default]:
      if arg == "self":
        # Python documentation typically skips `self` when printing method
        # signatures.
        continue
      args_list.append(arg)
    if argspec.defaults:
      for arg, default in zip(
          argspec.args[first_arg_with_default:], argspec.defaults):
        args_list.append("%s=%r" % (arg, default))
    if argspec.varargs:
      args_list.append("*" + argspec.varargs)
    if argspec.keywords:
      args_list.append("**" + argspec.keywords)
    return "(" + ", ".join(args_list) + ")"

  def _remove_docstring_indent(self, docstring):
    """Remove indenting.

    We follow Python's convention and remove the minimum indent of the lines
    after the first, see:
    https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
    preserving relative indentation.

    Args:
      docstring: A docstring.

    Returns:
      A list of strings, one per line, with the minimum indent stripped.
    """
    docstring = docstring or ""
    lines = docstring.strip().split("\n")

    # len(docstring) is a safe upper bound for the minimum indent.
    min_indent = len(docstring)
    for l in lines[1:]:
      l = l.rstrip()
      if l:
        i = 0
        while i < len(l) and l[i] == " ":
          i += 1
        if i < min_indent: min_indent = i
    # Strip the common indent from every line after the first.
    for i in range(1, len(lines)):
      l = lines[i].rstrip()
      if len(l) >= min_indent:
        l = l[min_indent:]
      lines[i] = l
    return lines

  def _print_formatted_docstring(self, docstring, f):
    """Formats the given `docstring` as Markdown and prints it to `f`."""
    lines = self._remove_docstring_indent(docstring)

    # Output the lines, identifying "Args" and other section blocks.
    i = 0

    def _at_start_of_section():
      """Returns the header if lines[i] is at start of a docstring section."""
      l = lines[i]
      match = _section_re.match(l)
      # A header only counts as a section if the next line is indented.
      if match and i + 1 < len(
          lines) and lines[i + 1].startswith(" "):
        return match.group(1)
      else:
        return None

    while i < len(lines):
      l = lines[i]

      section_header = _at_start_of_section()
      if section_header:
        if i == 0 or lines[i-1]:
          print >>f, ""
        # Use at least H4 to keep these out of the TOC.
        print >>f, "##### " + section_header + ":"
        print >>f, ""
        i += 1
        outputting_list = False
        while i < len(lines):
          l = lines[i]
          # A new section header terminates the section.
          if _at_start_of_section():
            break
          match = _arg_re.match(l)
          if match:
            if not outputting_list:
              # We need to start a list. In Markdown, a blank line needs to
              # precede a list.
              print >>f, ""
              outputting_list = True
            suffix = l[len(match.group()):].lstrip()
            print >>f, "* <b>" + match.group(1) + "</b>: " + suffix
          else:
            # For lines that don't start with _arg_re, continue the list if it
            # has enough indentation.
            outputting_list &= l.startswith("  ")
            print >>f, l
          i += 1
      else:
        print >>f, l
        i += 1

  def _print_function(self, f, prefix, fullname, func):
    """Prints the given function to `f`."""
    heading = prefix + " " + fullname
    if not isinstance(func, property):
      heading += self._generate_signature_for_function(func)
    heading += " {#%s}" % _get_anchor(self._module_to_name, fullname)
    print >>f, heading
    print >>f, ""
    self._print_formatted_docstring(inspect.getdoc(func), f)
    print >>f, ""

  def _write_member_markdown_to_file(self, f, name, member):
    """Print `member` to `f`."""
    if inspect.isfunction(member):
      print >>f, "- - -"
      print >>f, ""
      self._print_function(f, "###", name, member)
      print >>f, ""
    elif inspect.ismethod(member):
      print >>f, "- - -"
      print >>f, ""
      self._print_function(f, "####", name, member)
      print >>f, ""
    elif isinstance(member, property):
      print >>f, "- - -"
      print >>f, ""
      self._print_function(f, "####", name, member)
    elif inspect.isclass(member):
      print >>f, "- - -"
      print >>f, ""
      print >>f, "### class %s {#%s}" % (
          name, _get_anchor(self._module_to_name, name))
      print >>f, ""
      self._write_class_markdown_to_file(f, name, member)
      print >>f, ""
    else:
      raise RuntimeError("Member %s has unknown type %s" % (name, type(member)))

  def _write_docstring_markdown_to_file(self, f, docstring, members, imports):
    # Lines starting with the "@@" marker splice in that member's docs;
    # everything else is copied through verbatim.
    for l in self._remove_docstring_indent(docstring):
      if l.startswith(_member_mark):
        name = l[len(_member_mark):].strip(" \t")
        if name in members:
          self._documented.add(name)
          self._mentioned.add(name)
          self._write_member_markdown_to_file(f, *members[name])
          del members[name]
        elif name in imports:
          self._write_module_markdown_to_file(f, imports[name])
        else:
          raise ValueError("%s: unknown member `%s`" % (self._title, name))
      else:
        print >>f, l

  def _write_class_markdown_to_file(self, f, name, cls):
    """Write the class doc to 'f'.

    Args:
      f: File to write to.
      name: Qualified name to use for the class.
      cls: class object.
    """
    # Build the list of class methods to document.
    methods = dict(self.get_class_members(name, cls))
    # Used later to check if any methods were called out in the class
    # docstring.
    num_methods = len(methods)
    self._write_docstring_markdown_to_file(f, inspect.getdoc(cls), methods, {})

    # If some methods were not described, describe them now if they are
    # defined by the class itself (not inherited).  If NO methods were
    # described, describe all methods.
    #
    # TODO(mdevin): when all methods have been categorized make it an error
    # if some methods are not categorized.
    any_method_called_out = (len(methods) != num_methods)
    if any_method_called_out:
      other_methods = {n: m for n, m in methods.iteritems()
                       if n in cls.__dict__}
      if other_methods:
        print >>f, "\n#### Other Methods"
    else:
      other_methods = methods
    for name in sorted(other_methods):
      self._write_member_markdown_to_file(f, *other_methods[name])

  def _write_module_markdown_to_file(self, f, module):
    imports = dict(self.get_imported_modules(module))
    self._write_docstring_markdown_to_file(f, inspect.getdoc(module),
                                           self._members, imports)

  def write_markdown_to_file(self, f):
    """Prints this library to file `f`.

    Args:
      f: File to write to.
    """
    print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
    print >>f, ""
    # TODO(mdevin): Do not insert these.  Let the doc writer put them in
    # the module docstring explicitly.
    print >>f, "#", self._title
    print >>f, "[TOC]"
    print >>f, ""
    if self._module is not None:
      self._write_module_markdown_to_file(f, self._module)

  def write_other_members(self, f, catch_all=False):
    """Writes the leftover members to `f`.

    Args:
      f: File to write to.
      catch_all: If true, document all missing symbols from any module.
        Otherwise, document missing symbols from just this module.
    """
    if catch_all:
      # NOTE(review): iteritems() yields (name, (fullname, member)) pairs;
      # the loop below only uses the first element of each pair.
      names = self._members.iteritems()
    else:
      names = inspect.getmembers(self._module)
    leftovers = []
    for name, _ in names:
      if name in self._members and name not in self._documented:
        leftovers.append(name)
    if leftovers:
      # This print goes to stdout (progress report), not to `f`.
      print "%s: undocumented members: %d" % (self._title, len(leftovers))
      print >>f, "\n## Other Functions and Classes"
      for name in sorted(leftovers):
        print "  %s" % name
        self._documented.add(name)
        self._mentioned.add(name)
        self._write_member_markdown_to_file(f, *self._members[name])

  def assert_no_leftovers(self):
    """Generate an error if there are leftover members."""
    leftovers = []
    for name in self._members.iterkeys():
      # `name in self._members` is trivially true while iterating its keys;
      # the effective filter is the `not in self._documented` clause.
      if name in self._members and name not in self._documented:
        leftovers.append(name)
    if leftovers:
      raise RuntimeError("%s: undocumented members: %s" %
                         (self._title, ", ".join(leftovers)))
+
+
def write_libraries(dir, libraries):
  """Write a list of libraries to disk.

  For each (filename, library) pair, the library's markdown is written to
  `dir`/filename, followed by any members no library explicitly mentioned.

  Args:
    dir: Output directory.
    libraries: List of (filename, library) pairs.
  """
  files = [open(os.path.join(dir, k), "w") for k, _ in libraries]
  try:
    # Document mentioned symbols for all libraries
    for f, (_, v) in zip(files, libraries):
      v.write_markdown_to_file(f)
    # Document symbols that no library mentioned.  We do this after writing
    # out all libraries so that earlier libraries know what later libraries
    # documented.
    for f, (_, v) in zip(files, libraries):
      v.write_other_members(f)
  finally:
    # Fixed: previously files were only closed on the success path, leaking
    # all open handles if any write raised.
    for f in files:
      f.close()
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
new file mode 100644
index 0000000000..fe8f107cec
--- /dev/null
+++ b/tensorflow/python/framework/errors.py
@@ -0,0 +1,410 @@
+"""Exception types for TensorFlow errors."""
+import traceback
+import warnings
+
+from tensorflow.core.lib.core import error_codes_pb2
+
+
+class OpError(Exception):
+  """A generic error that is raised when TensorFlow execution fails.
+
+  Whenever possible, the session will raise a more specific subclass
+  of `OpError` from the `tf.errors` module.
+
+  @@op
+  @@node_def
+  """
+
+  def __init__(self, node_def, op, message, error_code):
+    """Creates a new OpError indicating that a particular op failed.
+
+    Args:
+      node_def: The graph_pb2.NodeDef proto representing the op that failed.
+      op: The ops.Operation that failed, if known; otherwise None.
+      message: The message string describing the failure.
+      error_code: The error_codes_pb2.Code describing the error.
+    """
+    super(OpError, self).__init__()
+    self._message = message
+    self._node_def = node_def
+    self._op = op
+    self._error_code = error_code
+
+  @property
+  def message(self):
+    """The error message that describes the error."""
+    return self._message
+
+  @property
+  def op(self):
+    """The operation that failed, if known.
+
+    *N.B.* If the failed op was synthesized at runtime, e.g. a `Send`
+    or `Recv` op, there will be no corresponding
+    [`Operation`](framework.md#Operation) object. In that case, this
+    will return `None`, and you should instead use the
+    [`node_def`](OpError.node_def) to discover information about the op.
+
+    Returns:
+      The `Operation` that failed, or None.
+    """
+    return self._op
+
+  @property
+  def error_code(self):
+    """The integer error code that describes the error."""
+    return self._error_code
+
+  @property
+  def node_def(self):
+    """The `NodeDef` proto representing the op that failed."""
+    return self._node_def
+
+  def __str__(self):
+    # With a known op we render the op's creation traceback, walking the
+    # chain of `_original_op` links (ops replaced by others at graph
+    # construction time); otherwise just return the plain message.
+    if self._op is not None:
+      output = ["%s\nCaused by op %r, defined at:\n"
+                % (self.message, self._op.name,)]
+      curr_traceback_list = traceback.format_list(self._op.traceback)
+      output.extend(curr_traceback_list)
+      original_op = self._op._original_op
+      while original_op is not None:
+        output.append(
+            "\n...which was originally created as op %r, defined at:\n"
+            % (original_op.name,))
+        prev_traceback_list = curr_traceback_list
+        curr_traceback_list = traceback.format_list(original_op.traceback)
+
+        # Attempt to elide large common subsequences of the subsequent
+        # stack traces.  Of a run of identical lines, the first and last
+        # are printed and the middle is replaced by an "[elided ...]"
+        # marker.
+        #
+        # TODO(mrry): Consider computing the actual longest common subsequence.
+        is_eliding = False
+        elide_count = 0
+        last_elided_line = None
+        for line, line_in_prev in zip(curr_traceback_list, prev_traceback_list):
+          if line == line_in_prev:
+            if is_eliding:
+              # Continuing a run of identical lines: count it and
+              # remember the most recent one so it can be re-printed.
+              elide_count += 1
+              last_elided_line = line
+            else:
+              # First identical line of a run: print it, start eliding.
+              output.append(line)
+              is_eliding = True
+              elide_count = 0
+          else:
+            if is_eliding:
+              if elide_count > 0:
+                output.extend(
+                    ["[elided %d identical lines from previous traceback]\n"
+                     % (elide_count - 1,), last_elided_line])
+              is_eliding = False
+            # NOTE(review): extend() with a string appends it char by
+            # char; the final ''.join() makes this equivalent to
+            # append(line), but append would be clearer.
+            output.extend(line)
+
+        original_op = original_op._original_op
+      return ''.join(output)
+    else:
+      return self.message
+
+
+# Re-export the canonical status codes from the error_codes proto so
+# that callers can refer to e.g. `errors.NOT_FOUND` directly.
+OK = error_codes_pb2.OK
+CANCELLED = error_codes_pb2.CANCELLED
+UNKNOWN = error_codes_pb2.UNKNOWN
+INVALID_ARGUMENT = error_codes_pb2.INVALID_ARGUMENT
+DEADLINE_EXCEEDED = error_codes_pb2.DEADLINE_EXCEEDED
+NOT_FOUND = error_codes_pb2.NOT_FOUND
+ALREADY_EXISTS = error_codes_pb2.ALREADY_EXISTS
+PERMISSION_DENIED = error_codes_pb2.PERMISSION_DENIED
+UNAUTHENTICATED = error_codes_pb2.UNAUTHENTICATED
+RESOURCE_EXHAUSTED = error_codes_pb2.RESOURCE_EXHAUSTED
+FAILED_PRECONDITION = error_codes_pb2.FAILED_PRECONDITION
+ABORTED = error_codes_pb2.ABORTED
+OUT_OF_RANGE = error_codes_pb2.OUT_OF_RANGE
+UNIMPLEMENTED = error_codes_pb2.UNIMPLEMENTED
+INTERNAL = error_codes_pb2.INTERNAL
+UNAVAILABLE = error_codes_pb2.UNAVAILABLE
+DATA_LOSS = error_codes_pb2.DATA_LOSS
+
+
+class CancelledError(OpError):
+  """Raised when an operation or step is cancelled.
+
+  For example, a long-running operation (e.g.
+  [`queue.enqueue()`](io_ops.md#QueueBase.enqueue)) may be cancelled by
+  running another operation (e.g.
+  [`queue.close(cancel_pending_enqueues=True)`](io_ops.md#QueueBase.close)),
+  or by [closing the session](client.md#Session.close). A step that is
+  running such a long-running operation will fail by raising `CancelledError`.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `CancelledError`."""
+    super(CancelledError, self).__init__(node_def, op, message, CANCELLED)
+
+
+class UnknownError(OpError):
+  """Unknown error.
+
+  An example of where this error may be returned is if a Status value
+  received from another address space belongs to an error-space that
+  is not known to this address space. Also errors raised by APIs that
+  do not return enough error information may be converted to this
+  error.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message, error_code=UNKNOWN):
+    """Creates an `UnknownError`.
+
+    Unlike the other subclasses, this constructor accepts an explicit
+    `error_code` (defaulting to `UNKNOWN`) so that unrecognized codes
+    can be preserved by the fallback in `_make_specific_exception`.
+    """
+    super(UnknownError, self).__init__(node_def, op, message, error_code)
+
+
+class InvalidArgumentError(OpError):
+  """Raised when an operation receives an invalid argument.
+
+  This may occur, for example, if an operation receives an input
+  tensor that has an invalid value or shape. For example, the
+  [`tf.matmul()`](math_ops.md#matmul) op will raise this error if it
+  receives an input that is not a matrix, and the
+  [`tf.reshape()`](array_ops.md#reshape) op will raise this error if
+  the new shape does not match the number of elements in the input
+  tensor.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `InvalidArgumentError`."""
+    super(InvalidArgumentError, self).__init__(node_def, op, message,
+                                               INVALID_ARGUMENT)
+
+
+class DeadlineExceededError(OpError):
+  """Raised when a deadline expires before an operation could complete.
+
+  This exception is not currently used.
+
+  Corresponds to the `DEADLINE_EXCEEDED` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `DeadlineExceededError`."""
+    super(DeadlineExceededError, self).__init__(node_def, op, message,
+                                                DEADLINE_EXCEEDED)
+
+
+class NotFoundError(OpError):
+  """Raised when a requested entity (e.g., a file or directory) was not found.
+
+  For example, running the
+  [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation
+  could raise `NotFoundError` if it receives the name of a file that
+  does not exist.
+
+  Corresponds to the `NOT_FOUND` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `NotFoundError`."""
+    super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND)
+
+
+class AlreadyExistsError(OpError):
+  """Raised when an entity that we attempted to create already exists.
+
+  For example, running an operation that saves a file
+  (e.g. [`tf.train.Saver.save()`](train.md#Saver.save)) could
+  potentially raise this exception if an explicit filename for an
+  existing file was passed.
+
+  Corresponds to the `ALREADY_EXISTS` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `AlreadyExistsError`."""
+    super(AlreadyExistsError, self).__init__(node_def, op, message,
+                                             ALREADY_EXISTS)
+
+
+class PermissionDeniedError(OpError):
+  """Raised when the caller does not have permission to run an operation.
+
+  For example, running the
+  [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation
+  could raise `PermissionDeniedError` if it receives the name of a
+  file for which the user does not have read permission.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `PermissionDeniedError`."""
+    super(PermissionDeniedError, self).__init__(node_def, op, message,
+                                                PERMISSION_DENIED)
+
+
+class UnauthenticatedError(OpError):
+  """The request does not have valid authentication credentials.
+
+  This exception is not currently used.
+
+  Corresponds to the `UNAUTHENTICATED` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `UnauthenticatedError`."""
+    super(UnauthenticatedError, self).__init__(node_def, op, message,
+                                               UNAUTHENTICATED)
+
+
+class ResourceExhaustedError(OpError):
+  """Some resource has been exhausted.
+
+  For example, this error might be raised if a per-user quota is
+  exhausted, or perhaps the entire file system is out of space.
+
+  Corresponds to the `RESOURCE_EXHAUSTED` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `ResourceExhaustedError`."""
+    super(ResourceExhaustedError, self).__init__(node_def, op, message,
+                                                 RESOURCE_EXHAUSTED)
+
+
+class FailedPreconditionError(OpError):
+  """Operation was rejected because the system is not in a state to execute it.
+
+  This exception is most commonly raised when running an operation
+  that reads a [`tf.Variable`](state_ops.md#Variable) before it has
+  been initialized.
+
+  Corresponds to the `FAILED_PRECONDITION` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `FailedPreconditionError`."""
+    super(FailedPreconditionError, self).__init__(node_def, op, message,
+                                                  FAILED_PRECONDITION)
+
+
+class AbortedError(OpError):
+  """The operation was aborted, typically due to a concurrent action.
+
+  For example, running a [`queue.enqueue()`](io_ops.md#QueueBase.enqueue)
+  operation may raise `AbortedError` if a
+  [`queue.close()`](io_ops.md#QueueBase.close) operation previously ran.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `AbortedError`."""
+    super(AbortedError, self).__init__(node_def, op, message, ABORTED)
+
+
+class OutOfRangeError(OpError):
+  """Raised when an operation executed past the valid range.
+
+  This exception is raised in "end-of-file" conditions, such as when a
+  [`queue.dequeue()`](io_ops.md#QueueBase.dequeue) operation is
+  blocked on an empty queue, and a
+  [`queue.close()`](io_ops.md#QueueBase.close) operation executes.
+
+  Corresponds to the `OUT_OF_RANGE` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `OutOfRangeError`."""
+    super(OutOfRangeError, self).__init__(node_def, op, message,
+                                          OUT_OF_RANGE)
+
+
+class UnimplementedError(OpError):
+  """Raised when an operation has not been implemented.
+
+  Some operations may raise this error when passed otherwise-valid
+  arguments that they do not currently support. For example, running
+  the [`tf.nn.max_pool()`](nn.md#max_pool) operation would raise this
+  error if pooling was requested on the batch dimension, because this
+  is not yet supported.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `UnimplementedError`."""
+    super(UnimplementedError, self).__init__(node_def, op, message,
+                                             UNIMPLEMENTED)
+
+
+class InternalError(OpError):
+  """Raised when the system experiences an internal error.
+
+  This exception is raised when some invariant expected by the runtime
+  has been broken. Catching this exception is not recommended.
+
+  Corresponds to the `INTERNAL` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `InternalError`."""
+    super(InternalError, self).__init__(node_def, op, message, INTERNAL)
+
+
+class UnavailableError(OpError):
+  """Raised when the runtime is currently unavailable.
+
+  This exception is not currently used.
+
+  Corresponds to the `UNAVAILABLE` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates an `UnavailableError`."""
+    super(UnavailableError, self).__init__(node_def, op, message,
+                                           UNAVAILABLE)
+
+
+class DataLossError(OpError):
+  """Raised when unrecoverable data loss or corruption is encountered.
+
+  For example, this may be raised by running a
+  [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation,
+  if the file is truncated while it is being read.
+
+  Corresponds to the `DATA_LOSS` status code.
+
+  @@__init__
+  """
+
+  def __init__(self, node_def, op, message):
+    """Creates a `DataLossError`."""
+    super(DataLossError, self).__init__(node_def, op, message, DATA_LOSS)
+
+
+# Maps each non-OK status code to the OpError subclass raised for it.
+# OK deliberately has no entry; unknown codes fall back to UnknownError
+# in _make_specific_exception.
+_CODE_TO_EXCEPTION_CLASS = {
+    CANCELLED: CancelledError,
+    UNKNOWN: UnknownError,
+    INVALID_ARGUMENT: InvalidArgumentError,
+    DEADLINE_EXCEEDED: DeadlineExceededError,
+    NOT_FOUND: NotFoundError,
+    ALREADY_EXISTS: AlreadyExistsError,
+    PERMISSION_DENIED: PermissionDeniedError,
+    UNAUTHENTICATED: UnauthenticatedError,
+    RESOURCE_EXHAUSTED: ResourceExhaustedError,
+    FAILED_PRECONDITION: FailedPreconditionError,
+    ABORTED: AbortedError,
+    OUT_OF_RANGE: OutOfRangeError,
+    UNIMPLEMENTED: UnimplementedError,
+    INTERNAL: InternalError,
+    UNAVAILABLE: UnavailableError,
+    DATA_LOSS: DataLossError,
+}
+
+
+def _make_specific_exception(node_def, op, message, error_code):
+  """Returns an instance of the OpError subclass for `error_code`.
+
+  Falls back to `UnknownError` (after emitting a warning) when
+  `error_code` has no registered exception class.
+  """
+  try:
+    exc_type = _CODE_TO_EXCEPTION_CLASS[error_code]
+    return exc_type(node_def, op, message)
+  except KeyError:
+    warnings.warn("Unknown error code: %d" % error_code)
+    return UnknownError(node_def, op, message, error_code)
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
new file mode 100644
index 0000000000..ab59a729f6
--- /dev/null
+++ b/tensorflow/python/framework/errors_test.py
@@ -0,0 +1,63 @@
+"""Tests for tensorflow.python.framework.errors."""
+import tensorflow.python.platform
+
+import warnings
+
+import tensorflow as tf
+
+from tensorflow.core.lib.core import error_codes_pb2
+
+class ErrorsTest(tf.test.TestCase):
+  """Tests the status-code -> exception-class mapping in tf.errors."""
+
+  def testUniqueClassForEachErrorCode(self):
+    # Each status code must map to its dedicated OpError subclass.
+    for error_code, exc_type in [
+        (tf.errors.CANCELLED, tf.errors.CancelledError),
+        (tf.errors.UNKNOWN, tf.errors.UnknownError),
+        (tf.errors.INVALID_ARGUMENT, tf.errors.InvalidArgumentError),
+        (tf.errors.DEADLINE_EXCEEDED, tf.errors.DeadlineExceededError),
+        (tf.errors.NOT_FOUND, tf.errors.NotFoundError),
+        (tf.errors.ALREADY_EXISTS, tf.errors.AlreadyExistsError),
+        (tf.errors.PERMISSION_DENIED, tf.errors.PermissionDeniedError),
+        (tf.errors.UNAUTHENTICATED, tf.errors.UnauthenticatedError),
+        (tf.errors.RESOURCE_EXHAUSTED, tf.errors.ResourceExhaustedError),
+        (tf.errors.FAILED_PRECONDITION, tf.errors.FailedPreconditionError),
+        (tf.errors.ABORTED, tf.errors.AbortedError),
+        (tf.errors.OUT_OF_RANGE, tf.errors.OutOfRangeError),
+        (tf.errors.UNIMPLEMENTED, tf.errors.UnimplementedError),
+        (tf.errors.INTERNAL, tf.errors.InternalError),
+        (tf.errors.UNAVAILABLE, tf.errors.UnavailableError),
+        (tf.errors.DATA_LOSS, tf.errors.DataLossError),
+    ]:
+      # pylint: disable=protected-access
+      self.assertTrue(isinstance(
+          tf.errors._make_specific_exception(None, None, None, error_code),
+          exc_type))
+      # pylint: enable=protected-access
+
+  def testKnownErrorClassForEachErrorCodeInProto(self):
+    # OK and the reserved sentinel code have no exception class by design.
+    for error_code in error_codes_pb2.Code.values():
+      # pylint: disable=line-too-long
+      if error_code in (error_codes_pb2.OK,
+                        error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_):
+        continue
+      # pylint: enable=line-too-long
+      with warnings.catch_warnings(record=True) as w:
+        # pylint: disable=protected-access
+        exc = tf.errors._make_specific_exception(None, None, None, error_code)
+        # pylint: enable=protected-access
+      self.assertEqual(0, len(w))  # No warning is raised.
+      self.assertTrue(isinstance(exc, tf.errors.OpError))
+      self.assertTrue(tf.errors.OpError in exc.__class__.__bases__)
+
+  def testUnknownErrorCodeCausesWarning(self):
+    # 37 is not a registered status code, so the fallback path must warn
+    # and return the generic UnknownError.
+    with warnings.catch_warnings(record=True) as w:
+      # pylint: disable=protected-access
+      exc = tf.errors._make_specific_exception(None, None, None, 37)
+      # pylint: enable=protected-access
+    self.assertEqual(1, len(w))
+    self.assertTrue("Unknown error code: 37" in str(w[0].message))
+    self.assertTrue(isinstance(exc, tf.errors.OpError))
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
new file mode 100644
index 0000000000..e317cfda8d
--- /dev/null
+++ b/tensorflow/python/framework/framework_lib.py
@@ -0,0 +1,70 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""Import names from the framework library.
+
+## Core graph data structures
+
+@@Graph
+@@Operation
+@@Tensor
+
+## Tensor types
+
+@@DType
+@@as_dtype
+
+## Utility functions
+
+@@device
+@@name_scope
+@@control_dependencies
+@@convert_to_tensor
+@@get_default_graph
+@@import_graph_def
+
+## Graph collections
+
+@@add_to_collection
+@@get_collection
+@@GraphKeys
+
+## Defining new operations
+
+@@RegisterGradient
+@@NoGradient
+@@RegisterShape
+@@TensorShape
+@@Dimension
+@@op_scope
+@@get_seed
+"""
+
+# Classes used when building a Graph.
+from tensorflow.python.framework.ops import Graph
+from tensorflow.python.framework.ops import Operation
+from tensorflow.python.framework.ops import Tensor
+from tensorflow.python.framework.ops import SparseTensor
+from tensorflow.python.framework.ops import SparseTensorValue
+from tensorflow.python.framework.ops import IndexedSlices
+
+# Utilities used when building a Graph.
+from tensorflow.python.framework.ops import device
+from tensorflow.python.framework.ops import name_scope
+from tensorflow.python.framework.ops import op_scope
+from tensorflow.python.framework.ops import control_dependencies
+from tensorflow.python.framework.ops import get_default_graph
+from tensorflow.python.framework.ops import GraphKeys
+from tensorflow.python.framework.ops import add_to_collection
+from tensorflow.python.framework.ops import get_collection
+from tensorflow.python.framework.ops import convert_to_tensor
+from tensorflow.python.framework.random_seed import get_seed
+from tensorflow.python.framework.random_seed import set_random_seed
+from tensorflow.python.framework.importer import import_graph_def
+
+# Needed when you defined a new Op in C++.
+from tensorflow.python.framework.ops import RegisterGradient
+from tensorflow.python.framework.ops import NoGradient
+from tensorflow.python.framework.ops import RegisterShape
+from tensorflow.python.framework.tensor_shape import Dimension
+from tensorflow.python.framework.tensor_shape import TensorShape
+
+from tensorflow.python.framework.types import *
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
new file mode 100644
index 0000000000..a726d880e7
--- /dev/null
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -0,0 +1,114 @@
+"""Updates generated docs from Python doc comments."""
+
+import os.path
+
+import tensorflow.python.platform
+import sys
+import tensorflow as tf
+
+from tensorflow.python.framework import docs
+from tensorflow.python.framework import framework_lib
+from tensorflow.python.client import client_lib
+
+
+# Command-line flags for the doc generator.
+tf.flags.DEFINE_string("out_dir", None,
+                       "Directory to which docs should be written.")
+tf.flags.DEFINE_boolean("print_hidden_regex", False,
+                        "Dump a regular expression matching any hidden symbol")
+FLAGS = tf.flags.FLAGS
+
+
+def get_module_to_name():
+  """Returns a dict mapping each documented module to its display name."""
+  return {tf: 'tf',
+          tf.errors: 'tf.errors',
+          tf.image: 'tf.image',
+          tf.nn: 'tf.nn',
+          tf.train: 'tf.train',
+          tf.python_io: 'tf.python_io'}
+
+def all_libraries(module_to_name, members, documented):
+  """Returns the list of (filename, docs.Library) pairs to generate.
+
+  Args:
+    module_to_name: Dict mapping modules to display names, as returned by
+      get_module_to_name().
+    members: Dict of visible members, as returned by docs.collect_members().
+    documented: Set of documented names, shared across the libraries.
+  """
+  # A list of (filename, docs.Library) pairs representing the individual files
+  # that we want to create.
+  def library(name, title, module=None, **args):
+    # Builds one (filename, Library) pair; `module` defaults to the
+    # tensorflow.python.ops submodule named by `name`.
+    if module is None:
+      module = sys.modules["tensorflow.python.ops" +
+                           ("" if name == "ops" else "." + name)]
+    return (name + ".md", docs.Library(title=title,
+                                       module_to_name=module_to_name,
+                                       members=members,
+                                       documented=documented,
+                                       module=module,
+                                       **args))
+  return [
+      # Splits of module 'tf'.
+      library("framework", "Building Graphs", framework_lib),
+      library("constant_op", "Constants, Sequences, and Random Values"),
+      library("state_ops", "Variables"),
+      library("array_ops", "Tensor Transformations",
+              exclude_symbols=["list_diff"]),
+      library("math_ops", "Math",
+              exclude_symbols=["sparse_matmul", "arg_min", "arg_max",
+                               "lin_space", "sparse_segment_mean_grad"]),
+      library("control_flow_ops", "Control Flow"),
+      library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"]),
+      library("sparse_ops", "Sparse Tensors"),
+      library("io_ops", "Inputs and Readers",
+              exclude_symbols=["LookupTableBase", "HashTable",
+                               "initialize_all_tables",
+                               "string_to_hash_bucket"]),
+      library("python_io", "Data IO (Python functions)", tf.python_io),
+      # NOTE(review): "batch_norm_with_global_normalization_grad" is
+      # listed twice in the excludes below.
+      library("nn", "Neural Network", tf.nn,
+              exclude_symbols=["deconv2d", "conv2d_backprop_input",
+                               "conv2d_backprop_filter", "avg_pool_grad",
+                               "max_pool_grad", "max_pool_grad_with_argmax",
+                               "batch_norm_with_global_normalization_grad",
+                               "lrn_grad", "relu6_grad", "softplus_grad",
+                               "xw_plus_b", "relu_layer", "lrn",
+                               "batch_norm_with_global_normalization",
+                               "batch_norm_with_global_normalization_grad",
+                               "all_candidate_sampler"]),
+      library('client', "Running Graphs", client_lib,
+              exclude_symbols=["InteractiveSession"]),
+      library("train", "Training", tf.train,
+              exclude_symbols=["Feature", "Features", "BytesList", "FloatList",
+                               "Int64List", "Example", "InferenceExample",
+                               "RankingExample", "SequenceExample"]),
+  ]
+
+# Symbols deliberately excluded from the catch-all library and folded
+# into the --print_hidden_regex output in main() below.
+_hidden_symbols = ["Event", "Summary",
+                   "HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
+                   "GPUOptions", "SessionInterface", "BaseSession"]
+
+def main(unused_argv):
+  """Generates all markdown docs (and index.md) under FLAGS.out_dir."""
+  if not FLAGS.out_dir:
+    tf.logging.error("out_dir not specified")
+    # Non-zero return signals failure to the tf.app.run() wrapper --
+    # presumably used as the process exit code; confirm.
+    return -1
+
+  # Document libraries
+  documented = set()
+  module_to_name = get_module_to_name()
+  members = docs.collect_members(module_to_name)
+  libraries = all_libraries(module_to_name, members, documented)
+  docs.write_libraries(FLAGS.out_dir, libraries)
+
+  # Make it easy to search for hidden symbols
+  if FLAGS.print_hidden_regex:
+    hidden = set(_hidden_symbols)
+    for _, lib in libraries:
+      hidden.update(lib.exclude_symbols)
+    print r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden))
+
+  # Verify that all symbols are mentioned in some library doc.
+  catch_all = docs.Library(title="Catch All", module=None,
+                           exclude_symbols=_hidden_symbols,
+                           module_to_name=module_to_name, members=members,
+                           documented=documented)
+  catch_all.assert_no_leftovers()
+
+  # Generate index
+  with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f:
+    docs.Index(module_to_name, members, libraries).write_markdown_to_file(f)
+
+
+if __name__ == "__main__":
+  tf.app.run()
diff --git a/tensorflow/python/framework/gen_docs_test.sh b/tensorflow/python/framework/gen_docs_test.sh
new file mode 100755
index 0000000000..fda214d93c
--- /dev/null
+++ b/tensorflow/python/framework/gen_docs_test.sh
@@ -0,0 +1,4 @@
+#!/bin/bash -eux
+DIR=$TEST_SRCDIR/tensorflow/python
+$DIR/gen_docs_combined --out_dir $TEST_TMPDIR
+echo "PASS"
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
new file mode 100644
index 0000000000..6ad2a1b009
--- /dev/null
+++ b/tensorflow/python/framework/importer.py
@@ -0,0 +1,303 @@
+"""A utility function for importing TensorFlow graphs."""
+import contextlib
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types as types_lib
+
+
+# TODO(josh11b): SWIG the code from node_def_util instead of duplicating
+# the logic here.
+def _GetNodeAttr(node_def, attr_name):
+  """Returns `node_def.attr[attr_name]`, raising ValueError if absent."""
+  if attr_name not in node_def.attr:
+    raise ValueError('Expected one attr with name %r in %s.'
+                     % (attr_name, str(node_def)))
+  return node_def.attr[attr_name]
+
+
+def _ArgToTypesNoRef(node_def, arg_def):
+  """Returns the list of dtype enums for `arg_def`, ignoring is_ref.
+
+  Handles the three OpDef arg forms: `number_attr` (N repeats of a
+  single type), `type_attr`/fixed `type` (one type), and
+  `type_list_attr` (an explicit list of types).
+  """
+  if arg_def.number_attr:
+    repeats = _GetNodeAttr(node_def, arg_def.number_attr).i
+    if arg_def.type_attr:
+      dtype = _GetNodeAttr(node_def, arg_def.type_attr).type
+    else:
+      assert arg_def.type != types_pb2.DT_INVALID
+      dtype = arg_def.type
+    return [dtype] * repeats
+  elif arg_def.type_attr:
+    return [_GetNodeAttr(node_def, arg_def.type_attr).type]
+  elif arg_def.type_list_attr:
+    return _GetNodeAttr(node_def, arg_def.type_list_attr).list.type
+  else:
+    assert arg_def.type != types_pb2.DT_INVALID
+    return [arg_def.type]
+
+
+def _SingleArgToTypes(node_def, arg_def):
+  """Returns the dtype enums for `arg_def`, as ref types when is_ref is set."""
+  types = _ArgToTypesNoRef(node_def, arg_def)
+  if arg_def.is_ref:
+    return [types_lib.as_dtype(dt).as_ref.as_datatype_enum for dt in types]
+  return types
+
+
+def _ArgsToTypes(node_def, arg_list):
+  """Concatenates the dtype enums of every arg_def in `arg_list`."""
+  types = []
+  for arg_def in arg_list:
+    types.extend(_SingleArgToTypes(node_def, arg_def))
+  return types
+
+
+def _InputTypes(node_def, op_dict):
+  """Returns the input dtype enums of `node_def`, using its OpDef in `op_dict`."""
+  op_def = op_dict[node_def.op]
+  return _ArgsToTypes(node_def, op_def.input_arg)
+
+
+def _OutputTypes(node_def, op_dict):
+  """Returns the output dtype enums of `node_def`, using its OpDef in `op_dict`."""
+  op_def = op_dict[node_def.op]
+  return _ArgsToTypes(node_def, op_def.output_arg)
+
+
+def _IsControlInput(input_name):
+  """Returns True if `input_name` names a control input ('^op_name')."""
+  # Expected format: '^operation_name' (control input).
+  return input_name.startswith('^')
+
+
+def _ParseTensorName(tensor_name):
+  """Parses a tensor name into an operation name and output index.
+
+  This function will canonicalize tensor names as follows:
+
+  * "foo:0"       -> ("foo", 0)
+  * "foo:7"       -> ("foo", 7)
+  * "foo"         -> ("foo", 0)
+  * "foo:bar:baz" -> ValueError
+
+  Args:
+    tensor_name: The name of a tensor.
+
+  Returns:
+    A tuple containing the operation name, and the output index.
+
+  Raises:
+    ValueError: If `tensor_name` cannot be interpreted as the name of a tensor.
+  """
+  components = tensor_name.split(':')
+  if len(components) == 2:
+    # Expected format: 'operation_name:output_index'.
+    try:
+      output_index = int(components[1])
+    except ValueError:
+      raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
+    return components[0], output_index
+  elif len(components) == 1:
+    # Expected format: 'operation_name' (implicit 0th output).
+    return components[0], 0
+  else:
+    raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
+
+
+def _CanonicalInputName(input_name):
+  """Returns `input_name` as 'op_name:output_index' (control inputs unchanged)."""
+  if _IsControlInput(input_name):
+    return input_name
+  input_op_name, output_index = _ParseTensorName(input_name)
+  return '%s:%d' % (input_op_name, output_index)
+
+
+def _InvalidNodeMessage(node, message):
+  """Formats an error message that names the offending node."""
+  return 'graph_def is invalid at node %r: %s.' % (node.name, message)
+
+
+@contextlib.contextmanager
+def _MaybeDevice(device):
+  """Applies the given device only if device is not None or empty."""
+  # With a falsy device this is a no-op context manager.
+  if device:
+    with ops.device(device):
+      yield
+  else:
+    yield
+
+
+def import_graph_def(graph_def, input_map=None, return_elements=None,
+                     name=None, op_dict=None):
+  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
+
+  This function provides a way to import a serialized TensorFlow
+  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+  protocol buffer, and extract individual objects in the `GraphDef` as
+  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
+  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
+  `GraphDef` proto.
+
+  Args:
+    graph_def: A `GraphDef` proto containing operations to be imported into
+      the default graph.
+    input_map: A dictionary mapping input names (as strings) in `graph_def`
+      to `Tensor` objects. The values of the named input tensors in the
+      imported graph will be re-mapped to the respective `Tensor` values.
+    return_elements: A list of strings containing operation names in
+      `graph_def` that will be returned as `Operation` objects; and/or
+      tensor names in `graph_def` that will be returned as `Tensor` objects.
+    name: (Optional.) A prefix that will be prepended to the names in
+      `graph_def`. Defaults to `"import"`.
+    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
+      Must contain an `OpDef` proto for each op type named in `graph_def`.
+      If omitted, uses the `OpDef` protos registered in the global registry.
+
+  Returns:
+    A list of `Operation` and/or `Tensor` objects from the imported graph,
+    corresponding to the names in `return_elements`.
+
+  Raises:
+    TypeError: If `graph_def` is not a `GraphDef` proto,
+      `input_map` is not a dictionary mapping strings to `Tensor` objects,
+      or `return_elements` is not a list of strings.
+    ValueError: If `input_map`, or `return_elements` contains names that
+      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
+      it refers to an unknown tensor).
+  """
+  # Type checks for inputs.
+  if not isinstance(graph_def, graph_pb2.GraphDef):
+    raise TypeError('graph_def must be a GraphDef proto.')
+  if input_map is None:
+    input_map = {}
+  else:
+    if not (isinstance(input_map, dict)
+            and all(isinstance(k, basestring) for k in input_map.keys())):
+      raise TypeError('input_map must be a dictionary mapping strings to '
+                      'Tensor objects.')
+  if (return_elements is not None
+      and not (isinstance(return_elements, (list, tuple))
+               and all(isinstance(x, basestring) for x in return_elements))):
+    raise TypeError('return_elements must be a list of strings.')
+
+  # Use a canonical representation for all tensor names.
+  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
+  used_input_keys = set()
+
+  name_to_op = {}
+
+  if op_dict is None:
+    op_dict = op_def_registry.get_registered_ops()
+
+  with ops.op_scope(input_map.values(), name, 'import'):
+    g = ops.get_default_graph()
+
+    with ops.name_scope('_inputs'):
+      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
+
+    # NOTE(mrry): We do this in two passes, because there may be a cycle in
+    # `graph_def`.
+
+    # 1. Add operations without their inputs.
+    for node in graph_def.node:
+      output_types = _OutputTypes(node, op_dict)
+      with _MaybeDevice(node.device):
+        # Inputs are attached in pass 2; shapes are inferred then, too.
+        name_to_op[node.name] = g.create_op(
+            node.op, [], output_types, name=node.name, attrs=node.attr,
+            compute_shapes=False)
+
+    # 2. Add inputs to the operations.
+    for node in graph_def.node:
+      op = name_to_op[node.name]
+      input_types = _InputTypes(node, op_dict)
+
+      # NOTE(mrry): We cannot use zip here because control inputs do not appear
+      # in the list of input_types.
+      for i, input_name in enumerate(
+          [_CanonicalInputName(x) for x in node.input]):
+
+        if _IsControlInput(input_name):
+          # (a) Input is a control input that should be taken from an op
+          #     in "graph_def".  Strip the leading '^' to get the op name.
+          try:
+            source_op = name_to_op[input_name[1:]]
+          except KeyError:
+            raise ValueError(
+                _InvalidNodeMessage(
+                    node,
+                    'Control input %r not found in graph_def.' % (input_name,)))
+          # pylint: disable=protected-access
+          op._add_control_input(source_op)
+          # pylint: enable=protected-access
+
+        else:
+          try:
+            input_type = input_types[i]
+          except IndexError:
+            raise ValueError(_InvalidNodeMessage(
+                node, 'More inputs specified (%r) than the op expects.'
+                % (input_name,)))
+
+          if input_name in input_map:
+            # (b) Input should be replaced by a tensor from the caller.
+            source_tensor = input_map[input_name]
+            used_input_keys.add(input_name)
+
+          else:
+            # (c) Input should be taken from an op in `graph_def`.
+            operation_name, output_index = _ParseTensorName(input_name)
+            try:
+              source_op = name_to_op[operation_name]
+              source_tensor = source_op.values()[output_index]
+            except (KeyError, IndexError):
+              raise ValueError(
+                  _InvalidNodeMessage(
+                      node,
+                      'Input tensor %r not found in graph_def.'
+                      % (input_name,)))
+
+          try:
+            # pylint: disable=protected-access
+            op._add_input(source_tensor, dtype=input_type)
+            # pylint: enable=protected-access
+          except TypeError as te:
+            raise ValueError(
+                _InvalidNodeMessage(node, 'Input tensor %r %s'
+                                    % (input_name, te.message)))
+
+      # Sanity check: after wiring, the op's input dtypes must match what
+      # the OpDef prescribes for this node.
+      # pylint: disable=protected-access
+      if op._input_dtypes != input_types:
+        raise ValueError(
+            _InvalidNodeMessage(
+                node,
+                'Input types mismatch (expected %r but got %r)'
+                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
+                   ", ".join(x.name for x in op._input_dtypes))))
+      # pylint: enable=protected-access
+
+      # Execute shape inference for this op.
+      # NOTE(mrry): If the graph contains a cycle, the full shape information
+      # may not be available for this op's inputs.
+      ops.set_shapes_for_outputs(op)
+
+    # Treat unused input mappings as an error, because they are likely to be
+    # due to a typo.
+    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
+    if unused_input_keys:
+      raise ValueError(
+          'Attempted to map inputs that were not found in graph_def: [%s]'
+          % ', '.join(unused_input_keys))
+
+    if return_elements is None:
+      return None
+    else:
+      ret = []
+      for name in return_elements:
+        if ':' in name:
+          # A tensor name: look up the op and index into its outputs.
+          try:
+            operation_name, output_index = _ParseTensorName(name)
+            ret.append(name_to_op[operation_name].outputs[output_index])
+          except (ValueError, KeyError, IndexError):
+            raise ValueError(
+                'Requested return_element %r not found in graph_def.' % name)
+        else:
+          # An operation name.
+          try:
+            ret.append(name_to_op[name])
+          except KeyError:
+            raise ValueError(
+                'Requested return_element %r not found in graph_def.' % name)
+      return ret
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
new file mode 100644
index 0000000000..470092313a
--- /dev/null
+++ b/tensorflow/python/framework/importer_test.py
@@ -0,0 +1,546 @@
+"""Tests for tensorflow.python.framework.importer."""
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import device
+from tensorflow.python.framework import op_def_registry
+
+
+_op_list = op_def_pb2.OpList()
+text_format.Merge("""
+ op {
+ name: 'None'
+ }
+ op {
+ name: 'Oi'
+ output_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'Or'
+ output_arg { name: 'a' type: DT_INT32 is_ref: true }
+ }
+ op {
+ name: 'Of'
+ output_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Ii'
+ input_arg { name: 'a' type: DT_INT32 }
+ }
+ op {
+ name: 'If'
+ input_arg { name: 'a' type: DT_FLOAT }
+ }
+ op {
+ name: 'Oii'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Oif'
+ output_arg { name: 'a' type: DT_INT32 }
+ output_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iii'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'Iff'
+ input_arg { name: 'a' type: DT_FLOAT }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iif'
+ input_arg { name: 'a' type: DT_INT32 }
+ input_arg { name: 'b' type: DT_FLOAT }
+ }
+ op {
+ name: 'Iri'
+ input_arg { name: 'a' type: DT_INT32 is_ref: true }
+ input_arg { name: 'b' type: DT_INT32 }
+ }
+ op {
+ name: 'In'
+ input_arg { name: 'a' number_attr: 'N' type_attr: 'T' }
+ attr { name: 'N' type: 'int' minimum: 1 }
+ attr { name: 'T' type: 'type' }
+ }
+ op {
+ name: 'Otl'
+ output_arg { name: 'a' type_list_attr: 't' }
+ attr { name: 'T' type: 'list(type)' minimum: 1 }
+ }
+ op {
+ name: 'Unary'
+ input_arg { name: 'a' type_attr: 'T' }
+ output_arg { name: 'b' type_attr: 'T' }
+ attr { name: 'T' type: 'type' }
+ }
+""", _op_list)
+op_def_registry.register_op_list(_op_list)
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+for op_def in _op_list.op:
+ tf.RegisterShape(op_def.name)(None)
+
+class ImportGraphDefTest(tf.test.TestCase):
+
+ def _MakeGraphDef(self, text):
+ ret = tf.GraphDef()
+ text_format.Merge(text, ret)
+ return ret
+
+ def testBasic(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oif' }
+ node { name: 'B' op: 'Otl'
+ attr { key: 't'
+ value { list { type: DT_INT32 type: DT_FLOAT } } } }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_FLOAT } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'],
+ name='import')
+
+ # Assert that the import process creates distinct tensors.
+ self.assertNotEqual(a.outputs[0].name, a.outputs[1].name)
+ self.assertNotEqual(b.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[0].name, b.outputs[1].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[0].name)
+ self.assertNotEqual(a.outputs[1].name, b.outputs[1].name)
+
+ # Assert that the ops are connected according to the GraphDef topology.
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], b.outputs[1])
+
+ # Check the types of the returned ops and tensors.
+ self.assertEqual(a.type, 'Oif')
+ self.assertEqual(b.type, 'Otl')
+ self.assertEqual(c.type, 'In')
+ self.assertEqual(d.type, 'In')
+ self.assertEqual(a.outputs[0].dtype, tf.int32)
+ self.assertEqual(a.outputs[1].dtype, tf.float32)
+ self.assertEqual(b.outputs[0].dtype, tf.int32)
+ self.assertEqual(b.outputs[1].dtype, tf.float32)
+
+ # Check the names of the returned ops.
+ self.assertEqual(a.name, 'import/A')
+ self.assertEqual(b.name, 'import/B')
+ self.assertEqual(c.name, 'import/C')
+ self.assertEqual(d.name, 'import/D')
+
+ def testInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ feed_b_1 = tf.constant(1, dtype=tf.int32)
+
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Oii' }
+ node { name: 'C' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'In'
+ attr { key: 'N' value { i: 2 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ input: 'A:1' input: 'B:1' }
+ """),
+ input_map={'A:0': feed_a_0, 'B:1': feed_b_1},
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], feed_a_0)
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[1])
+ self.assertEqual(d.inputs[1], feed_b_1)
+
+ def testImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testInputMapImplicitZerothOutput(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(0, dtype=tf.int32)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oii' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A': feed_a_0},
+ return_elements=['B'])
+
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testWithControlDependency(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ node { name: 'B' op: 'None' input: '^A' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(b.control_inputs, [a])
+
+ def testWithRefs(self):
+ with tf.Graph().as_default():
+ a, b, c, d = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Or' }
+ node { name: 'B' op: 'Oi' }
+ node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' }
+ node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' }
+ """),
+ return_elements=['A', 'B', 'C', 'D'])
+
+ self.assertEqual(c.inputs[0], a.outputs[0])
+ self.assertEqual(c.inputs[1], b.outputs[0])
+ self.assertEqual(d.inputs[0], a.outputs[0])
+ self.assertEqual(d.inputs[1], b.outputs[0])
+
+ self.assertEqual(a.outputs[0].dtype, tf.int32_ref)
+ self.assertEqual(c._input_dtypes, [tf.int32, tf.int32])
+ self.assertEqual(c.outputs, [])
+ self.assertEqual(d._input_dtypes,
+ [tf.int32_ref, tf.int32])
+ self.assertEqual(d.outputs, [])
+
+ def testCyclic(self):
+ with tf.Graph().as_default():
+ a, b = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' }
+ node { name: 'B' op: 'Unary'
+ attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' }
+ """),
+ return_elements=['A', 'B'])
+
+ self.assertEqual(a.inputs[0], b.outputs[0])
+ self.assertEqual(b.inputs[0], a.outputs[0])
+
+ def testTypeMismatchInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue(
+ 'Cannot convert a tensor of type int32 to an input of type float' in
+ str(e.exception))
+
+ def testInvalidSignatureTooManyInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'None' input: 'A:0' }
+ """))
+ self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in
+ str(e.exception))
+
+ def testInvalidSignatureNotEnoughInputsInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Iif' input: 'A:0' }
+ """))
+ self.assertTrue('Input types mismatch (expected \'int32, float32\' but '
+ 'got \'int32\')' in str(e.exception))
+
+ def testMissingInputOpInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:0',) in
+ str(e.exception))
+
+ def testMissingInputOpInGraphDefButAppearsInInputMap(self):
+ with tf.Graph().as_default():
+ feed_a_0 = tf.constant(5.0)
+ b, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'If' input: 'A:0' }
+ """),
+ input_map={'A:0': feed_a_0},
+ return_elements=['B'])
+ self.assertEqual(b.inputs[0], feed_a_0)
+
+ def testMissingInputTensorInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Of' }
+ node { name: 'B' op: 'If' input: 'A:1' }
+ """))
+ self.assertTrue('Input tensor %r not found' % (u'A:1',) in
+ str(e.exception))
+
+ def testMissingControlInputInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: '^A' }
+ """))
+ self.assertTrue('Control input %r not found' % (u'^A',) in
+ str(e.exception))
+
+ def testInvalidTensorNameOutputIndexInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B',), str(e.exception))
+
+ def testInvalidTensorNameInGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'B' op: 'None' input: 'A:B:0' }
+ """))
+ self.assertEqual(
+ 'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception))
+
+ def testMissingReturnOperation(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['B'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B') in
+ str(e.exception))
+
+ def testMissingReturnTensor(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:1'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in
+ str(e.exception))
+
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ """),
+ return_elements=['A:B:0'])
+ self.assertTrue('return_element %r not found in graph_def.' % ('A:B:0') in
+ str(e.exception))
+
+ def testMissingInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ input_map={'B:0': tf.constant(5.0)})
+ self.assertTrue('not found in graph_def: [B:0]' in str(e.exception))
+
+ def testInputMapTypeMismatch(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(ValueError) as e:
+ tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'Oi' }
+ node { name: 'B' op: 'Ii' input: 'A:0' }
+ """),
+ input_map={'A:0': tf.constant(5.0)})
+ self.assertTrue(
+ 'Cannot convert a tensor of type float32 to an input of type int32.'
+ in str(e.exception))
+
+ def testNoReturns(self):
+ with tf.Graph().as_default() as g:
+ ret = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """))
+ self.assertEqual(ret, None)
+
+ a = g.get_operation_by_name('import/A')
+ self.assertEqual(a.type, 'None')
+
+ def testOverrideNamePrefix(self):
+ with tf.Graph().as_default():
+ a, = tf.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'None' }
+ """),
+ return_elements=['A'], name='imported_graph')
+ self.assertEqual(a.name, 'imported_graph/A')
+
+ def testEmptyGraph(self):
+ with tf.Graph().as_default() as g:
+ init_version = g.version
+ tf.import_graph_def(self._MakeGraphDef(''))
+ self.assertEqual(init_version, g.version)
+
+ def testInvalidInputForGraphDef(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def('')
+ self.assertEqual(
+ 'graph_def must be a GraphDef proto.', str(e.exception))
+
+ def testInvalidInputForInputMap(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''),
+ input_map=[tf.constant(5.0)])
+ self.assertEqual('input_map must be a dictionary mapping strings to '
+ 'Tensor objects.', str(e.exception))
+
+ def testInvalidInputForReturnOperations(self):
+ with tf.Graph().as_default():
+ with self.assertRaises(TypeError) as e:
+ tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7])
+ self.assertEqual(
+ 'return_elements must be a list of strings.', str(e.exception))
+
+ def testWithExtensionAndAttr(self):
+ with tf.Graph().as_default() as g:
+ c = tf.constant(5.0, dtype=tf.float32, name='c')
+ tf.pack([c, c], name='pack')
+ gdef = g.as_graph_def()
+
+ with self.test_session():
+ pack, = tf.import_graph_def(gdef, return_elements=['pack'])
+ self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
+
+ def testWithDevice(self):
+ with tf.Graph().as_default() as g:
+ # No device.
+ a = tf.constant(3.0, name='a')
+
+ with tf.device('/cpu:0'):
+ b = tf.constant(4.0, name='b')
+ with tf.device('/job:worker'):
+ c = tf.constant(5.0, name='c')
+
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default():
+ a2, b2, c2 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual(a.device, a2.device)
+ self.assertEqual(b.device, b2.device)
+ self.assertEqual(c.device, c2.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/task:0')):
+ a3, b3, c3 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/task:0', a3.device)
+ self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized.
+ self.assertEqual(c.device + '/task:0', c3.device)
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/job:ps')):
+ a4, b4, c4 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/job:ps', a4.device)
+ self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized.
+ self.assertEqual(c.device, c4.device) # worker overrides ps.
+
+ with tf.Graph().as_default():
+ with tf.device(device.merge_device('/gpu:0')):
+ a5, b5, c5 = tf.import_graph_def(
+ gdef, return_elements=['a', 'b', 'c'])
+ self.assertEqual('/device:GPU:0', a5.device)
+ self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu.
+ self.assertEqual(c.device + '/device:GPU:0', c5.device)
+
+ def testGradient(self):
+ with tf.Graph().as_default() as g:
+ inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input")
+ weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights")
+ biases = tf.placeholder(tf.float32, shape=[10], name="biases")
+ activations = tf.nn.relu(tf.matmul(inputs, weights) + biases,
+ name="activations")
+ loss = tf.reduce_mean(activations, name="loss")
+ gdef = g.as_graph_def()
+
+ with tf.Graph().as_default() as g:
+ input_placeholder = tf.placeholder(tf.float32, shape=[32, 100])
+ weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights")
+ biases_var = tf.Variable(tf.zeros(10), name="biases")
+ activations, loss = tf.import_graph_def(
+ gdef,
+ input_map={"input:0": input_placeholder,
+ "weights:0": weights_var,
+ "biases:0": biases_var},
+ return_elements=["activations:0", "loss:0"])
+ self.assertEqual([32, 10], activations.get_shape())
+ self.assertEqual([], loss.get_shape())
+ weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var])
+ self.assertEqual([100, 10], weights_grad.get_shape())
+ self.assertEqual([10], biases_grad.get_shape())
+
+ def testLargeGraph(self):
+ with self.test_session():
+      # The default message byte limit is 64M. Ours is 2G with a warning at 512M.
+ # Adding a 150M entries float32 tensor should blow through the warning,
+ # but not the hard limit.
+ input_shape = [150, 1024, 1024]
+ tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32)
+ t = tf.constant(tensor_input, shape=input_shape)
+ g = tf.identity(t)
+ g.eval()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/framework/op_def_registry.py b/tensorflow/python/framework/op_def_registry.py
new file mode 100644
index 0000000000..2ec8c94a10
--- /dev/null
+++ b/tensorflow/python/framework/op_def_registry.py
@@ -0,0 +1,23 @@
+"""Global registry for OpDefs."""
+
+from tensorflow.core.framework import op_def_pb2
+
+
+_registered_ops = {}
+
+
+def register_op_list(op_list):
+ """Register all the ops in an op_def_pb2.OpList."""
+ if not isinstance(op_list, op_def_pb2.OpList):
+ raise TypeError("%s is %s, not an op_def_pb2.OpList" %
+ (op_list, type(op_list)))
+ for op_def in op_list.op:
+ if op_def.name in _registered_ops:
+ assert _registered_ops[op_def.name] == op_def
+ else:
+ _registered_ops[op_def.name] = op_def
+
+
+def get_registered_ops():
+ """Returns a dictionary mapping names to OpDefs."""
+ return _registered_ops
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
new file mode 100644
index 0000000000..0b0442cea1
--- /dev/null
+++ b/tensorflow/python/framework/ops.py
@@ -0,0 +1,2985 @@
+"""Classes and functions used to construct graphs."""
+# pylint: disable=g-bad-name
+import collections
+import contextlib
+import copy
+import linecache
+import re
+import sys
+import threading
+import weakref
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import registry
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+
+
+def _convert_stack(stack):
+ """Converts a stack extracted using _extract_stack() to a traceback stack.
+
+ Args:
+ stack: A list of n 4-tuples, (filename, lineno, name, frame_globals).
+
+ Returns:
+ A list of n 4-tuples (filename, lineno, name, code), where the code tuple
+ element is calculated from the corresponding elements of the input tuple.
+ """
+ ret = []
+ for filename, lineno, name, frame_globals in stack:
+ linecache.checkcache(filename)
+ line = linecache.getline(filename, lineno, frame_globals)
+ if line:
+ line = line.strip()
+ else:
+ line = None
+ ret.append((filename, lineno, name, line))
+ return ret
+
+
+# pylint: disable=line-too-long
+def _extract_stack():
+ """A lightweight re-implementation of traceback.extract_stack.
+
+ NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
+ each stack frame using linecache, which results in an abundance of stat()
+ calls. This implementation does not retrieve the code, and any consumer
+ should apply _convert_stack to the result to obtain a traceback that can
+ be formatted etc. using traceback methods.
+
+ Returns:
+ A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to
+ the call stack of the current thread.
+ """
+ # pylint: enable=line-too-long
+ try:
+ raise ZeroDivisionError
+ except ZeroDivisionError:
+ f = sys.exc_info()[2].tb_frame.f_back
+ ret = []
+ while f is not None:
+ lineno = f.f_lineno
+ co = f.f_code
+ filename = co.co_filename
+ name = co.co_name
+ frame_globals = f.f_globals
+ ret.append((filename, lineno, name, frame_globals))
+ f = f.f_back
+ ret.reverse()
+ return ret
+
+
+class Tensor(object):
+ """Represents a value produced by an `Operation`.
+
+ A `Tensor` is a symbolic handle to one of the outputs of an
+ `Operation`. It does not hold the values of that operation's output,
+ but instead provides a means of computing those values in a
+ TensorFlow [`Session`](client.md#Session).
+
+ This class has two primary purposes:
+
+ 1. A `Tensor` can be passed as an input to another `Operation`.
+ This builds a dataflow connection between operations, which
+ enables TensorFlow to execute an entire `Graph` that represents a
+ large, multi-step computation.
+
+ 2. After the graph has been launched in a session, the value of the
+ `Tensor` can be computed by passing it to
+ [`Session.run()`](client.md#Session.run).
+ `t.eval()` is a shortcut for calling
+ `tf.get_default_session().run(t)`.
+
+ In the following example, `c`, `d`, and `e` are symbolic `Tensor`
+ objects, whereas `result` is a numpy array that stores a concrete
+ value:
+
+ ```python
+ # Build a dataflow graph.
+ c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
+ d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
+ e = tf.matmul(c, d)
+
+  # Construct a `Session` to execute the graph.
+ sess = tf.Session()
+
+ # Execute the graph and store the value that `e` represents in `result`.
+ result = sess.run(e)
+ ```
+
+ @@dtype
+ @@name
+ @@value_index
+ @@graph
+ @@op
+ @@consumers
+
+ @@eval
+
+ @@get_shape
+ @@set_shape
+
+ """
+
+ # List of Python operators that we allow to override.
+ OVERLOADABLE_OPERATORS = {
+ # Binary.
+ "__add__", "__radd__",
+ "__sub__", "__rsub__",
+ "__mul__", "__rmul__",
+ "__div__", "__rdiv__",
+ "__truediv__", "__rtruediv__",
+ "__mod__", "__rmod__",
+ "__lt__", "__le__",
+ "__gt__", "__ge__",
+ "__and__", "__rand__",
+ "__or__", "__ror__",
+ "__xor__", "__rxor__",
+ "__getitem__",
+ # Unary.
+ "__invert__",
+ "__neg__", "__abs__"}
+
+ def __init__(self, op, value_index, dtype):
+ """Creates a new `Tensor`.
+
+ Args:
+ op: An `Operation`. `Operation` that computes this tensor.
+ value_index: An `int`. Index of the operation's endpoint that produces
+ this tensor.
+ dtype: A `types.DType`. Type of data stored in this tensor.
+
+ Raises:
+ TypeError: If the op is not an `Operation`.
+ """
+ if not isinstance(op, Operation):
+ raise TypeError("op needs to be an Operation: %s" % op)
+ self._op = op
+ self._value_index = value_index
+ self._dtype = types.as_dtype(dtype)
+ self._shape = tensor_shape.unknown_shape()
+ # List of operations that use this Tensor as input. We maintain this list
+ # to easily navigate a computation graph.
+ self._consumers = []
+
+ @property
+ def op(self):
+ """The `Operation` that produces this tensor as an output."""
+ return self._op
+
+ @property
+ def dtype(self):
+ """The `DType` of elements in this tensor."""
+ return self._dtype
+
+ @property
+ def graph(self):
+ """The `Graph` that contains this tensor."""
+ return self._op.graph
+
+ @property
+ def name(self):
+ """The string name of this tensor."""
+ if not self._op.name:
+ raise ValueError("Operation was not named: %s" % self._op)
+ return "%s:%d" % (self._op.name, self._value_index)
+
+ @property
+ def device(self):
+ """The name of the device on which this tensor will be produced, or None."""
+ return self._op.device
+
+ def _shape_as_list(self):
+ if self._shape.ndims is not None:
+ return [dim.value for dim in self._shape.dims]
+ else:
+ return None
+
+ def get_shape(self):
+ """Returns the `TensorShape` that represents the shape of this tensor.
+
+ The shape is computed using shape inference functions that are
+ registered for each `Operation` type using `tf.RegisterShape`.
+ See [`TensorShape`](framework.md#TensorShape) for more details of what a shape
+ represents.
+
+ The inferred shape of a tensor is used to provide shape
+ information without having to launch the graph in a session. This
+ can be used for debugging, and providing early error messages. For
+ example:
+
+ ```python
+ c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+
+ print c.get_shape()
+ ==> TensorShape([Dimension(2), Dimension(3)])
+
+ d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
+
+ print d.get_shape()
+ ==> TensorShape([Dimension(4), Dimension(2)])
+
+ # Raises a ValueError, because `c` and `d` do not have compatible
+ # inner dimensions.
+ e = tf.matmul(c, d)
+
+ f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
+
+ print f.get_shape()
+ ==> TensorShape([Dimension(3), Dimension(4)])
+ ```
+
+ In some cases, the inferred shape may have unknown dimensions. If
+ the caller has additional information about the values of these
+ dimensions, `Tensor.set_shape()` can be used to augment the
+ inferred shape.
+
+ Returns:
+ A `TensorShape` representing the shape of this tensor.
+ """
+ return self._shape
+
+ def set_shape(self, shape):
+ """Updates the shape of this tensor.
+
+ This method can be called multiple times, and will merge the given
+ `shape` with the current shape of this tensor. It can be used to
+ provide additional information about the shape of this tensor that
+ cannot be inferred from the graph alone. For example, this can be used
+ to provide additional information about the shapes of images:
+
+ ```python
+ _, image_data = tf.TFRecordReader(...).read(...)
+ image = tf.image.decode_png(image_data, channels=3)
+
+ # The height and width dimensions of `image` are data dependent, and
+ # cannot be computed without executing the op.
+ print image.get_shape()
+ ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
+
+ # We know that each image in this dataset is 28 x 28 pixels.
+ image.set_shape([28, 28, 3])
+ print image.get_shape()
+ ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
+ ```
+
+ Args:
+ shape: A `TensorShape` representing the shape of this tensor.
+
+ Raises:
+ ValueError: If `shape` is not compatible with the current shape of
+ this tensor.
+ """
+ self._shape = self._shape.merge_with(shape)
+
+ @property
+ def value_index(self):
+ """The index of this tensor in the outputs of its `Operation`."""
+ return self._value_index
+
+ def consumers(self):
+ """Returns a list of `Operation`s that consume this tensor.
+
+ Returns:
+ A list of `Operation`s.
+ """
+ return self._consumers
+
+ def _add_consumer(self, consumer):
+ """Add a consumer to this tensor.
+
+ Args:
+ consumer: an Operation.
+
+ Raises:
+ TypeError: if the consumer is not an Operation.
+ """
+ if not isinstance(consumer, Operation):
+ raise TypeError("Consumer must be an Operation: %s" % consumer)
+ self._consumers.append(consumer)
+
+ def _as_node_def_input(self):
+ """Return a value to use for the NodeDef "input" attribute.
+
+ The returned string can be used in a NodeDef "input" attribute
+ to indicate that the NodeDef uses this Tensor as input.
+
+ Raises:
+ ValueError: if this Tensor's Operation does not have a name.
+
+ Returns:
+ a string.
+ """
+ if not self._op.name:
+ raise ValueError("Operation was not named: %s" % self._op)
+ if self._value_index == 0:
+ return self._op.name
+ else:
+ return "%s:%d" % (self._op.name, self._value_index)
+
+ def __str__(self):
+ return "Tensor(\"%s\"%s%s%s)" % (
+ self.name,
+ (", shape=%s" % self.get_shape())
+ if self.get_shape().ndims is not None else "",
+ (", dtype=%s" % self._dtype.name) if self._dtype else "",
+ (", device=%s" % self.device) if self.device else "")
+
+ def __hash__(self):
+ # Necessary to support Python's collection membership operators
+ return id(self)
+
+ def __eq__(self, other):
+ # Necessary to support Python's collection membership operators
+ return id(self) == id(other)
+
+ # NOTE(mrry): This enables the Tensor's overloaded "right" binary
+ # operators to run when the left operand is an ndarray, because it
+ # accords the Tensor class higher priority than an ndarray, or a
+ # numpy matrix.
+ # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
+ # mechanism, which allows more control over how Tensors interact
+ # with ndarrays.
+ __array_priority__ = 100
+
+ @staticmethod
+ def _override_operator(operator, func):
+ """Overrides (string) operator on Tensors to call func.
+
+ Args:
+ operator: the string name of the operator to override.
+      func: the function that replaces the overridden operator.
+
+ Raises:
+ ValueError: If operator has already been overwritten,
+ or if operator is not allowed to be overwritten.
+ """
+ if getattr(Tensor, operator, None) is not None:
+ # check to see if this is a default method-wrapper which will be true
+ # for the comparison operators.
+ if not isinstance(getattr(Tensor, operator, None), type(all.__call__)):
+ raise ValueError("operator %s cannot be overwritten again." % operator)
+ if operator not in Tensor.OVERLOADABLE_OPERATORS:
+ raise ValueError("Overriding %s is disallowed" % operator)
+ setattr(Tensor, operator, func)
+
+ def __iter__(self):
+ """Dummy method to prevent iteration. Do not call.
+
+ NOTE(mrry): If we register __getitem__ as an overloaded operator,
+ Python will valiantly attempt to iterate over the Tensor from 0 to
+ infinity. Declaring this method prevents this unintended
+ behavior.
+
+ Raises:
+ TypeError: when invoked.
+ """
+ raise TypeError("'Tensor' object is not iterable")
+
+ def eval(self, feed_dict=None, session=None):
+ """Evaluates this tensor in a `Session`.
+
+ Calling this method will execute all preceding operations that
+ produce the inputs needed for the operation that produces this
+ tensor.
+
+ *N.B.* Before invoking `Tensor.eval()`, its graph must have been
+ launched in a session, and either a default session must be
+ available, or `session` must be specified explicitly.
+
+ Args:
+ feed_dict: A dictionary that maps `Tensor` objects to feed values.
+ See [`Session.run()`](client.md#Session.run) for a description of
+ the valid feed values.
+ session: (Optional.) The `Session` to be used to evaluate this tensor. If
+ none, the default session will be used.
+
+ Returns:
+ A numpy array corresponding to the value of this tensor.
+
+ """
+ return _eval_using_default_session(self, feed_dict, self.graph, session)
+
+
+def _TensorTensorConversionFunction(t, dtype=None, name=None):
+ _ = name
+ if dtype and not dtype.is_compatible_with(t.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
+ % (dtype.name, t.dtype.name, str(t)))
+ return t
+
+
+_tensor_conversion_func_registry = {
+ 0: [(Tensor, _TensorTensorConversionFunction)]}
+
+
+def convert_to_tensor(value, dtype=None, name=None):
+ """Converts the given `value` to a `Tensor`.
+
+ This function converts Python objects of various types to `Tensor`
+ objects. It accepts `Tensor` objects, numpy arrays, Python lists,
+ and Python scalars. For example:
+
+ ```python
+ import numpy as np
+ array = np.random.rand((32, 100, 100))
+
+ def my_func(arg):
+ arg = tf.convert_to_tensor(arg, dtype=tf.float32)
+ return tf.matmul(arg, arg) + arg
+
+ # The following calls are equivalent.
+  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
+ value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
+  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
+ ```
+
+ This function can be useful when composing a new operation in Python
+ (such as `my_func` in the example above). All standard Python op
+ constructors apply this function to each of their Tensor-valued
+ inputs, which allows those ops to accept numpy arrays, Python lists,
+ and scalars in addition to `Tensor` objects.
+
+ Args:
+ value: An object whose type has a registered `Tensor` conversion function.
+ dtype: Optional element type for the returned tensor. If missing, the
+ type is inferred from the type of `value`.
+ name: Optional name to use if a new `Tensor` is created.
+
+ Returns:
+ A `Tensor` based on `value`.
+
+ Raises:
+ TypeError: If no conversion function is registered for `value`.
+ RuntimeError: If a registered conversion function returns an invalid value.
+
+ """
+ error_prefix = "" if name is None else "%s: " % name
+ if dtype is not None:
+ dtype = types.as_dtype(dtype)
+ for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
+ for base_type, conversion_func in funcs_at_priority:
+ if isinstance(value, base_type):
+ ret = conversion_func(value, dtype=dtype, name=name)
+ if not isinstance(ret, Tensor):
+ raise RuntimeError(
+ "%sConversion function %r for type %s returned non-Tensor: %r"
+ % (error_prefix, conversion_func, base_type, ret))
+ if dtype and not dtype.is_compatible_with(ret.dtype):
+ raise RuntimeError(
+ "%sConversion function %r for type %s returned incompatible "
+ "dtype: requested = %s, actual = %s"
+ % (error_prefix, conversion_func, base_type,
+ dtype.name, ret.dtype.name))
+ return ret
+ raise TypeError("%sCannot convert %r with type %s to Tensor: "
+ "no conversion function registered."
+ % (error_prefix, value, type(value)))
+
+
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or an `IndexedSlices` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, IndexedSlices):
    # BUG FIX: this previously called `types.AsDType`, which does not match
    # the spelling used everywhere else in this module (`types.as_dtype`,
    # e.g. in convert_to_tensor and Operation._add_input).
    if dtype and not types.as_dtype(dtype).is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
          % (types.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return convert_to_tensor(value, dtype, name)
+
+
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  `None` entries are passed through untouched; every other entry is run
  through `convert_to_tensor_or_indexed_slices()`.

  Args:
    values: A list of `None`, `IndexedSlices`, or objects that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`
      `IndexedSlices`.
    name: (Optional.) A name prefix to used when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections.Sequence):
    raise TypeError("values must be a list.")

  def _convert(index, value):
    # Preserve None placeholders; convert everything else.
    if value is None:
      return None
    element_name = None if name is None else "%s_%d" % (name, index)
    return convert_to_tensor_or_indexed_slices(value, dtype=dtype,
                                               name=element_name)

  return [_convert(i, v) for i, v in enumerate(values)]
+
+
def register_tensor_conversion_function(base_type, conversion_func,
                                        priority=100):
  """Registers a function for converting objects of base_type to Tensor.

  The conversion function must have the following signature:

      def conversion_func(value, dtype=None, name=None):
        # ...

  It must return a Tensor with the given dtype if specified. If the
  conversion function creates a new Tensor, it should use the given
  name if specified. All exceptions will be propagated to the caller.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion
  function F runs before another conversion function G, ensure that
  F is registered with a smaller priority than G.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of base_type to Tensor.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values
      run earlier than conversion functions with larger priority values.
      Defaults to 100.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
  """
  base_type_is_valid = isinstance(base_type, type) or (
      isinstance(base_type, tuple) and
      all(isinstance(entry, type) for entry in base_type))
  if not base_type_is_valid:
    raise TypeError("base_type must be a type or a tuple of types.")
  if not callable(conversion_func):
    raise TypeError("conversion_func must be callable.")
  # setdefault fetches the bucket for this priority, creating an empty one
  # on first use (equivalent to the try/except-KeyError idiom).
  _tensor_conversion_func_registry.setdefault(priority, []).append(
      (base_type, conversion_func))
+
+
class IndexedSlices(object):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. [`tf.gather`](array_ops.md#gather)).

  Contrast this representation with
  [`SparseTensor`](sparse_ops.md#SparseTensor),
  which uses multi-dimensional indices and scalar values.

  @@__init__

  @@values
  @@indices
  @@dense_shape

  @@name
  @@dtype
  @@device
  @@op
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`.

    Args:
      values: A `Tensor` with shape `[D0, D1, ..., Dn]` containing the
        values of the slices.
      indices: A 1-D integer `Tensor` with shape `[D0]` containing the
        indices of the slices in the first dimension of the dense tensor.
      dense_shape: (Optional.) A 1-D `Tensor` containing the shape of the
        corresponding dense tensor.
    """
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  # The name/device/op/dtype properties below delegate to `values`, so an
  # IndexedSlices can stand in for a Tensor in code that only reads these.
  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s)" % (
        self._indices, self._values)
+
+
def assert_same_graph(items, expected_graph=None):
  """Asserts all items are from the same graph.

  Args:
    items: List of graph items (e.g., Variable, Tensor, SparseTensor,
      Operation, or IndexedSlices).
    expected_graph: Expected graph. If not specified, the graph of the
      first item is used and all remaining items must match it.
  Returns:
    items, for chaining.
  Raises:
    ValueError: If any graphs do not match.
  """
  for item in items:
    item_graph = item.graph
    if not expected_graph:
      # First item seen (and no explicit graph given): adopt its graph.
      expected_graph = item_graph
    elif expected_graph != item_graph:
      raise ValueError("Items must be from the same graph.")
  return items
+
+
class SparseTensor(object):
  """Represents a sparse tensor.

  Tensorflow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`. In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use. If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the Ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, shape)` is

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`.
  * `values`: A 1-D tensor of any type and shape `[N]`.
  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`.

  where `N` and `ndims` are the number of values, and number of dimensions in
  the `SparseTensor` respectively.

  The corresponding dense tensor satisfies

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most Ops assume correct ordering.
  If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the
  misordered `SparseTensor`.

  Example: The sparse tensor

  ```python
  SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```

  @@__init__
  @@indices
  @@values
  @@dtype
  @@shape
  @@graph
  """

  def __init__(self, indices, values, shape):
    """Creates a `SparseTensor`.

    Each argument is converted to a `Tensor` (within a common
    "SparseTensor" op scope), and the static shapes are checked for
    consistency: `indices` must be rank 2, `values` and `shape` rank 1,
    with matching `N` and `ndims` dimensions where statically known.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      shape: A 1-D int64 tensor of shape `[ndims]`.
    """
    with op_scope([indices, values, shape], None, "SparseTensor"):
      indices = convert_to_tensor(indices, name="indices")
      values = convert_to_tensor(values, name="values")
      shape = convert_to_tensor(shape, name="shape")
    self._indices = indices
    self._values = values
    self._shape = shape

    indices_shape = indices.get_shape().with_rank(2)
    values_shape = values.get_shape().with_rank(1)
    shape_shape = shape.get_shape().with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape[0].merge_with(values_shape[0])
    # Assert number of columns in indices matches the number of elements in
    # shape.
    indices_shape[1].merge_with(shape_shape[0])

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
      number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._shape

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and shape tensors."""
    return self._indices.graph

  def __str__(self):
    return "SparseTensor(indices=%s, values=%s, shape=%s)" % (
        self._indices, self._values, self._shape)
+
+
# Plain-data counterpart of `SparseTensor`: a namedtuple carrying the same
# (indices, values, shape) triple as concrete values instead of graph tensors.
SparseTensorValue = collections.namedtuple("SparseTensorValue",
                                           ["indices", "values", "shape"])
+
+
def _device_string(dev_spec):
  """Returns the string form of `dev_spec` (a `pydev.Device` or a string)."""
  if not isinstance(dev_spec, pydev.Device):
    # Anything other than a Device object is assumed to already be a string.
    return dev_spec
  return dev_spec.to_string()
+
+
def _NodeDef(op_type, name, device=None, attrs=None):
  """Builds a graph_pb2.NodeDef with the given op type, name, device, attrs.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    device: string, device, or function from NodeDef to string.
      Value for the "device" attribute of the NodeDef proto.
    attrs: optional dict mapping attr names to AttrValue protos; each entry
      is copied into the "attr" map of the NodeDef proto.

  Returns:
    A graph_pb2.NodeDef protocol buffer.
  """
  node_def = graph_pb2.NodeDef()
  node_def.op = str(op_type)
  node_def.name = str(name)
  if attrs is not None:
    for attr_name, attr_value in attrs.iteritems():
      node_def.attr[attr_name].CopyFrom(attr_value)
  if device is not None:
    # A callable device is resolved against the NodeDef built so far.
    node_def.device = (device(node_def) if callable(device)
                       else _device_string(device))
  return node_def
+
+
+# Copied from core/framework/node_def_util.cc
+# TODO(mrry,josh11b): Consolidate this validation in C++ code.
+_VALID_OP_NAME_REGEX = re.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*")
+
+
class Operation(object):
  """Represents a graph node that performs computation on tensors.

  An `Operation` is a node in a TensorFlow `Graph` that takes zero or
  more `Tensor` objects as input, and produces zero or more `Tensor`
  objects as output. Objects of type `Operation` are created by
  calling a Python op constructor (such as [`tf.matmul()`](math_ops.md#matmul))
  or [`Graph.create_op()`](framework.md#Graph.create_op).

  For example `c = tf.matmul(a, b)` creates an `Operation` of type
  "MatMul" that takes tensors `a` and `b` as input, and produces `c`
  as output.

  After the graph has been launched in a session, an `Operation` can
  be executed by passing it to [`Session.run()`](client.md#Session.run).
  `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.

  @@name
  @@type
  @@inputs
  @@control_inputs
  @@outputs
  @@device
  @@graph

  @@run

  @@get_attr
  @@traceback
  """

  def __init__(self, node_def, g, inputs=None, output_types=None,
               control_inputs=None, input_types=None, original_op=None,
               op_def=None):
    """Creates an `Operation`.

    NOTE: This constructor validates the name of the Operation (passed
    as "node_def.name"). Valid Operation names match the following
    regular expression:

      [A-Za-z0-9.][A-Za-z0-9_.\\-/]*

    Args:
      node_def: graph_pb2.NodeDef. NodeDef for the Operation.
        Used for attributes of graph_pb2.NodeDef, typically "name",
        "op", and "device". The "input" attribute is irrelevant here
        as it will be computed when generating the model.
      g: Graph. The parent graph.
      inputs: list of Tensor objects. The inputs to this Operation.
      output_types: list of types_pb2.DataType. List of the types of the
        Tensors computed by this operation. The length of this list indicates
        the number of output endpoints of the Operation.
      control_inputs: list of operations or tensors from which to have a
        control dependency.
      input_types: List of types_pb2.DataType representing the
        types of the Tensors accepted by the Operation. By default
        uses [x.dtype.base_dtype for x in inputs]. Operations that expect
        reference-typed inputs must specify these explicitly.
      original_op: Optional. Used to associate the new Operation with an
        existing Operation (for example, a replica with the op that was
        replicated).
      op_def: Optional. The op_def_pb2.OpDef proto that describes the
        op type that this Operation represents.

    Raises:
      TypeError: if control inputs are not Operations or Tensors,
        or if node_def is not a NodeDef,
        or if g is not a Graph,
        or if inputs are not Tensors,
        or if inputs and input_types are incompatible.
      ValueError: if the node_def name is not valid.
    """
    if not isinstance(node_def, graph_pb2.NodeDef):
      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
    if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
      raise ValueError(
          "Cannot create an Operation with a NodeDef larger than 2GB.")
    if not _VALID_OP_NAME_REGEX.match(node_def.name):
      raise ValueError("'%s' is not a valid node name" % node_def.name)
    if not isinstance(g, Graph):
      raise TypeError("g needs to be a Graph: %s" % g)
    # Deep-copy so later mutations of the caller's NodeDef cannot affect
    # this op (and vice versa: _recompute_node_def rewrites the inputs).
    self._node_def = copy.deepcopy(node_def)
    self._graph = g
    if inputs is None:
      inputs = []
    self._inputs = inputs
    for a in self._inputs:
      if not isinstance(a, Tensor):
        raise TypeError("input needs to be a Tensor: %s" % a)
      # Mark that we consume the inputs.
      a._add_consumer(self)  # pylint: disable=protected-access
    if output_types is None:
      output_types = []
    self._output_types = output_types
    # One output Tensor per declared output type.
    self._outputs = [Tensor(self, i, output_types[i])
                     for i in xrange(len(output_types))]
    if input_types is None:
      input_types = [i.dtype.base_dtype for i in self._inputs]
    else:
      if not all(x.is_compatible_with(i.dtype)
                 for i, x in zip(self._inputs, input_types)):
        raise TypeError("Inputs are not compatible with input types")
    self._input_types = input_types

    # Build the list of control inputs.
    self._control_inputs = []
    if control_inputs:
      for c in control_inputs:
        c_op = None
        if isinstance(c, Operation):
          c_op = c
        elif isinstance(c, (Tensor, IndexedSlices)):
          # A Tensor/IndexedSlices control dependency means depending on
          # the op that produces it.
          c_op = c.op
        else:
          raise TypeError("Control input must be an Operation, "
                          "a Tensor, or IndexedSlices: %s" % c)
        self._control_inputs.append(c_op)

    self._original_op = original_op
    self._op_def = op_def
    self._traceback = _extract_stack()
    # Add this op to the current control flow context:
    self._control_flow_context = g._get_control_flow_context()
    if g._get_control_flow_context() is not None:
      g._get_control_flow_context().AddOp(self)
    # NOTE(keveman): Control flow context's AddOp could be creating new ops and
    # setting op.inputs[index] = new_op. Thus the new ops' id could be larger
    # than this op's id even though this op depend on them. Therefore, delaying
    # assigning id to this op until all ops this could be dependent on are
    # created.
    self._id_value = self._graph._next_id()  # pylint: disable=protected-access
    self._recompute_node_def()

  def values(self):
    """DEPRECATED: Use outputs."""
    return tuple(self.outputs)

  def _get_control_flow_context(self):
    """Returns the current control flow context.

    Returns:
      A context object.
    """
    return self._control_flow_context

  @property
  def name(self):
    """The full name of this operation."""
    return self._node_def.name

  @property
  def _id(self):
    """The unique integer id of this operation."""
    return self._id_value

  @property
  def device(self):
    """The name of the device to which this op has been assigned, if any.

    Returns:
      The string name of the device to which this op has been
      assigned, or None if it has not been assigned to a device.
    """
    dev = self._node_def.device
    return None if not dev else dev

  def _set_device(self, device):
    """Set the device of this operation.

    Args:
      device: string or device.. The device to set.
    """
    self._node_def.device = _device_string(device)

  def _add_input(self, tensor, dtype=None):
    """Add a new input to this operation.

    Args:
      tensor: the Tensor to add as an input.
      dtype: types.DType: type of the input; defaults to
        the tensor's dtype.

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    if not isinstance(tensor, Tensor):
      raise TypeError("tensor must be a Tensor: %s" % tensor)
    assert_same_graph([self, tensor])
    if dtype is None:
      dtype = tensor.dtype
    else:
      dtype = types.as_dtype(dtype)
      if not dtype.is_compatible_with(tensor.dtype):
        raise TypeError(
            "Cannot convert a tensor of type %s to an input of type %s"
            % (tensor.dtype.name, dtype.name))
    self._inputs.append(tensor)
    self._input_types.append(dtype)
    tensor._add_consumer(self)  # pylint: disable=protected-access
    # Keep the serialized NodeDef's "input" field in sync.
    self._recompute_node_def()

  def _update_input(self, index, tensor, dtype=None):
    """Update the input to this operation at the given index.

    NOTE: This is for TF internal use only. Please don't use it.

    Args:
      index: the index of the input to update.
      tensor: the Tensor to be used as the input at the given index.
      dtype: types.DType: type of the input; defaults to
        the tensor's dtype.

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    if not isinstance(tensor, Tensor):
      raise TypeError("tensor must be a Tensor: %s" % tensor)
    assert_same_graph([self, tensor])
    if dtype is None:
      dtype = tensor.dtype
    else:
      dtype = types.as_dtype(dtype)
      if not dtype.is_compatible_with(tensor.dtype):
        raise TypeError(
            "Cannot convert a tensor of type %s to an input of type %s"
            % (tensor.dtype.name, dtype.name))

    # Unregister this op as a consumer of the old input before replacing it.
    self._inputs[index].consumers().remove(self)
    self._inputs[index] = tensor
    self._input_types[index] = dtype
    tensor._add_consumer(self)  # pylint: disable=protected-access
    self._recompute_node_def()

  def _add_control_input(self, op):
    """Add a new control input to this operation.

    Args:
      op: the Operation to add as control input.

    Raises:
      TypeError: if op is not an Operation.
      ValueError: if op is from a different graph.
    """
    if not isinstance(op, Operation):
      raise TypeError("op must be an Operation: %s" % op)
    assert_same_graph([self, op])
    self._control_inputs.append(op)
    self._recompute_node_def()

  # Methods below are used when building the NodeDef and Graph proto.
  def _recompute_node_def(self):
    """Rebuilds this op's NodeDef "input" field from the current inputs.

    Control inputs are encoded with a "^" prefix, after the data inputs.
    """
    del self._node_def.input[:]
    self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
    if self._control_inputs:
      self._node_def.input.extend(["^%s" % op.name for op in
                                   self._control_inputs])

  def __str__(self):
    return str(self._node_def)

  @property
  def outputs(self):
    """The list of `Tensor` objects representing the outputs of this op."""
    return self._outputs

# pylint: disable=protected-access
  class _InputList(object):
    """Immutable input list wrapper."""

    def __init__(self, op):
      self._op = op

    def __iter__(self):
      return iter(self._op._inputs)

    def __len__(self):
      return len(self._op._inputs)

    # NOTE(review): under Python 2, truth testing uses __nonzero__ (with a
    # __len__ fallback, which exists above), not __bool__ — confirm whether
    # a __nonzero__ alias is needed for this codebase's Python version.
    def __bool__(self):
      return bool(self._op._inputs)

    def __getitem__(self, i):
      return self._op._inputs[i]
# pylint: enable=protected-access

  @property
  def inputs(self):
    """The list of `Tensor` objects representing the data inputs of this op."""
    return Operation._InputList(self)

  @property
  def _input_dtypes(self):
    return self._input_types

  @property
  def control_inputs(self):
    """The `Operation` objects on which this op has a control dependency.

    Before this op is executed, TensorFlow will ensure that the
    operations in `self.control_inputs` have finished executing. This
    mechanism can be used to run ops sequentially for performance
    reasons, or to ensure that the side effects of an op are observed
    in the correct order.

    Returns:
      A list of `Operation` objects.

    """
    return self._control_inputs

  @property
  def type(self):
    """The type of the op (e.g. `"MatMul"`)."""
    return self._node_def.op

  @property
  def graph(self):
    """The `Graph` that contains this operation."""
    return self._graph

  @property
  def node_def(self):
    """Returns a serialized `NodeDef` representation of this operation.

    Returns:
      A
      [`NodeDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
      protocol buffer.
    """
    return self._node_def

  @property
  def op_def(self):
    """Returns the `OpDef` proto that represents the type of this op.

    Returns:
      An
      [`OpDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def.proto)
      protocol buffer.
    """
    return self._op_def

  @property
  def traceback(self):
    """Returns the call stack from when this operation was constructed."""
    return _convert_stack(self._traceback)

  def get_attr(self, name):
    """Returns the value of the attr of this op with the given `name`.

    Args:
      name: The name of the attr to fetch.

    Returns:
      The value of the attr, as a Python object.

    Raises:
      ValueError: If this op does not have an attr with the given `name`.
    """
    fields = ["s", "i", "f", "b", "type", "shape", "tensor"]
    if name not in self._node_def.attr:
      raise ValueError("No attr named '" + name + "' in " +
                       str(self._node_def))
    x = self._node_def.attr[name]
    # Treat an empty oneof value as an empty list.
    if not x.WhichOneof("value"):
      return []
    if x.HasField("list"):
      # A list attr: return the first non-empty repeated field as a Python
      # list (at most one of the repeated fields is populated).
      for f in fields:
        if getattr(x.list, f):
          return list(getattr(x.list, f))
      return []
    else:
      for f in fields:
        if x.HasField(f):
          return getattr(x, f)
      assert False, "Unsupported field type in " + str(x)

  def run(self, feed_dict=None, session=None):
    """Runs this operation in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for this operation.

    *N.B.* Before invoking `Operation.run()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See [`Session.run()`](client.md#Session.run) for a description of the
        valid feed values.
      session: (Optional.) The `Session` to be used to run to this operation. If
        none, the default session will be used.
    """
    _run_using_default_session(self, feed_dict, self.graph, session)
+
+
# Registry mapping op type name -> gradient function. A value of None (see
# NoGradient) marks an op type as deliberately having no gradient.
_gradient_registry = registry.Registry("gradient")
+
+
class RegisterGradient(object):
  """A decorator for registering the gradient function for an op type.

  This decorator is only used when defining a new op type. For an op
  with `m` inputs and `n` outputs, the gradient function is a function
  that takes the original `Operation` and `n` `Tensor` objects
  (representing the gradients with respect to each output of the op),
  and returns `m` `Tensor` objects (representing the partial gradients
  with respect to each input of the op).

  For example, assuming that operations of type `"Sub"` take two
  inputs `x` and `y`, and return a single output `x - y`, the
  following gradient function would be registered:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.Neg(grad)
  ```

  The decorator argument `op_type` is the string type of an
  operation. This corresponds to the `OpDef.name` field for the proto
  that defines the operation.

  @@__init__
  """

  def __init__(self, op_type):
    """Saves `op_type` as the Operation type to register a gradient for.

    Args:
      op_type: The string type of an operation; matches the `OpDef.name`
        field of the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not a string.
    """
    if isinstance(op_type, basestring):
      self._op_type = op_type
    else:
      raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers `f` as the gradient function for the saved op type."""
    _gradient_registry.register(f, self._op_type)
    return f
+
+
def NoGradient(op_type):
  """Specifies that ops of type `op_type` do not have a defined gradient.

  This function is only used when defining a new op type. It may be
  used for ops such as `tf.size()` that are not differentiable. For
  example:

  ```python
  tf.NoGradient("Size")
  ```

  Args:
    op_type: The string type of an operation. This corresponds to the
      `OpDef.name` field for the proto that defines the operation.

  Raises:
    TypeError: If `op_type` is not a string.
  """
  if isinstance(op_type, basestring):
    # Registering None marks the op type as intentionally non-differentiable.
    _gradient_registry.register(None, op_type)
  else:
    raise TypeError("op_type must be a string")
+
+
def get_gradient_function(op):
  """Returns the function that computes gradients for "op"."""
  if not op.inputs:
    # An op without inputs has nothing to differentiate with respect to.
    return None
  try:
    # An op can redirect its gradient lookup via the "_gradient_op_type"
    # attr (see Graph.gradient_override_map).
    lookup_type = op.get_attr("_gradient_op_type")
  except ValueError:
    lookup_type = op.type
  return _gradient_registry.lookup(lookup_type)
+
+
# Shape functions registered via RegisterShape. An explicit (non-None)
# registration lives in _shape_registry; registering None installs a "weak"
# default in _default_shape_function_registry, consulted only as a fallback.
_shape_registry = registry.Registry("shape functions")
_default_shape_function_registry = registry.Registry("default shape functions")
+
class RegisterShape(object):
  """A decorator for registering the shape function for an op type.

  This decorator is only used when defining a new op type. A shape
  function is a function from an `Operation` object to a list of
  `TensorShape` objects, with one `TensorShape` for each output of the
  operation.

  For example, assuming that operations of type `"Sub"` take two
  inputs `x` and `y`, and return a single output `x - y`, all with the
  same shape, the following shape function would be registered:

  ```python
  @tf.RegisterShape("Sub")
  def _sub_shape(op):
    return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
  ```

  The decorator argument `op_type` is the string type of an
  operation. This corresponds to the `OpDef.name` field for the proto
  that defines the operation.

  """

  def __init__(self, op_type):
    """Saves the "op_type" as the Operation type."""
    if isinstance(op_type, basestring):
      self._op_type = op_type
    else:
      raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers "f" as the shape function for "op_type"."""
    if f is not None:
      _shape_registry.register(f, self._op_type)
      return f
    # None is a special "weak" value that provides a default shape function,
    # and can be overridden by a non-None registration.
    try:
      _default_shape_function_registry.register(_no_shape_function,
                                                self._op_type)
    except KeyError:
      # Ignore duplicate registrations of the weak value. This can
      # occur if the op library input to wrapper generation
      # inadvertently links in one or more of the standard op
      # libraries.
      pass
    return f
+
+
def _no_shape_function(op):
  """Default shape function: reports an unknown shape for every output."""
  shapes = []
  for _ in op.outputs:
    shapes.append(tensor_shape.unknown_shape())
  return shapes
+
+
def set_shapes_for_outputs(op):
  """Uses the registered shape functions to set the shapes for op's outputs.

  Looks up the shape function for `op.type`, falling back to the "weak"
  default registry (see RegisterShape), and applies the resulting shapes
  to `op.outputs`.

  Args:
    op: An `Operation` whose output shapes will be set.

  Raises:
    RuntimeError: If no shape function (explicit or default) is registered
      for `op.type`, or if the shape function returns the wrong number of
      shapes.
  """
  try:
    shape_func = _shape_registry.lookup(op.type)
  except LookupError:
    try:
      shape_func = _default_shape_function_registry.lookup(op.type)
    except LookupError:
      raise RuntimeError("No shape function registered for standard op: %s"
                         % op.type)
  shapes = shape_func(op)
  if len(op.outputs) != len(shapes):
    # BUG FIX: the original message had "returned" and "expecting" swapped —
    # the shape function returned len(shapes) entries while len(op.outputs)
    # were expected. Also use %d for integer counts.
    raise RuntimeError(
        "Shape function for op %s returned %d shapes but expecting %d" %
        (op, len(shapes), len(op.outputs)))
  for output, s in zip(op.outputs, shapes):
    output.set_shape(s)
+
+
+class Graph(object):
+ """A TensorFlow computation, represented as a dataflow graph.
+
+ A `Graph` contains a set of [`Operation`](framework.md#Operation) objects,
+ which represent units of computation; and [`Tensor`](framework.md#Tensor)
+ objects, which represent the units of data that flow between operations.
+
+ A default `Graph` is always registered, and accessible by calling
+ [`tf.get_default_graph()`](framework.md#get_default_graph). To add an
+ operation to the default graph, simply call one of the functions that defines
+ a new `Operation`:
+
+ ```
+ c = tf.constant(4.0)
+ assert c.graph is tf.get_default_graph()
+ ```
+
+ Another typical usage involves the
+ [`Graph.as_default()`](framework.md#Graph.as_default)
+ context manager, which overrides the current default graph for the
+ lifetime of the context:
+
+ ```python
+ g = tf.Graph()
+ with g.as_default():
+ # Define operations and tensors in `g`.
+ c = tf.constant(30.0)
+ assert c.graph is g
+ ```
+
+ Important note: This class *is not* thread-safe for graph construction. All
+ operations should be created from a single thread, or external
+ synchronization must be provided. Unless otherwise specified, all methods
+ are not thread-safe.
+
+ @@__init__
+ @@as_default
+ @@as_graph_def
+ @@finalize
+ @@finalized
+
+ @@control_dependencies
+ @@device
+ @@name_scope
+
+ A `Graph` instance supports an arbitrary number of "collections"
+ that are identified by name. For convenience when building a large
+ graph, collections can store groups of related objects: for
+ example, the `tf.Variable` uses a collection (named
+ [`tf.GraphKeys.VARIABLES`](framework.md#GraphKeys)) for all variables that are
+ created during the construction of a graph. The caller may define
+ additional collections by specifying a new name.
+
+ @@add_to_collection
+ @@get_collection
+
+ @@as_graph_element
+ @@get_operation_by_name
+ @@get_tensor_by_name
+ @@get_operations
+
+ @@get_default_device
+ @@seed
+ @@unique_name
+ @@version
+
+ @@create_op
+ @@gradient_override_map
+ """
+
  def __init__(self):
    """Creates a new, empty Graph."""
    # Ops indexed by their unique integer id (see _next_id).
    self._nodes_by_id = dict()
    self._next_node_id = [dict()]
    self._next_id_counter = 0
    # Ops indexed by their full name.
    self._nodes_by_name = dict()
    # Current name stack: a pair of uniquified names and plain names.
    self._name_stack = ("", "")
    # Maps a name used in the graph to the next id to use for that name.
    self._names_in_use = {}
    # Default device applied to new ops.
    self._default_device = None
    # Functions that will be applied to choose a device if none is specified.
    self._device_function_stack = []
    # Default original_op applied to new ops.
    self._default_original_op = None
    # Current control flow context. It could be either CondContext or
    # WhileContext defined in ops/control_flow_ops.py
    self._control_flow_context = None
    # A new node will depend on the union of all of the nodes in the stack.
    self._control_dependencies_stack = []
    # Arbitrary collections of objects.
    self._collections = {}
    # The graph-level random seed.
    self._seed = None
    # A map from op type to the kernel label that should be used.
    self._op_to_kernel_label_map = {}
    # A map from op type to an alternative op type that should be used when
    # computing gradients.
    self._gradient_override_map = {}
    # True if the graph is considered "finalized". In that case no
    # new operations can be added.
    self._finalized = False
+
+ def _check_not_finalized(self):
+ """Check if the graph is finalized.
+
+ Raises:
+ RuntimeError: If the graph finalized.
+ """
+ if self._finalized:
+ raise RuntimeError("Graph is finalized and cannot be modified.")
+
+ def _add_op(self, op):
+ """Adds 'op' to the graph.
+
+ Args:
+ op: the Operator or Tensor to add.
+
+ Raises:
+ TypeError: if op is not an Operation or Tensor.
+ ValueError: if the op.name or op._id are already used.
+ """
+ self._check_not_finalized()
+ if not isinstance(op, (Tensor, Operation)):
+ raise TypeError("op must be a Tensor or Operation: %s" % op)
+
+ if op._id in self._nodes_by_id:
+ raise ValueError("cannot add an op with id %d as it already "
+ "exists in the graph" % op._id)
+ if op.name in self._nodes_by_name:
+ raise ValueError("cannot add op with name %s as that name "
+ "is already used" % op.name)
+ self._nodes_by_id[op._id] = op
+ self._nodes_by_name[op.name] = op
+
+ @property
+ def version(self):
+ """Returns a version number that increases as ops are added to the graph."""
+ return self._next_id_counter
+
  @property
  def seed(self):
    """The graph-level random seed, or None if it has not been set."""
    return self._seed

  @seed.setter
  def seed(self, seed):
    """Sets the graph-level random seed (stored as-is; no validation)."""
    self._seed = seed
+
+ @property
+ def finalized(self):
+ """True if this graph has been finalized."""
+ return self._finalized
+
+ def finalize(self):
+ """Finalizes this graph, making it read-only.
+
+ After calling `g.finalize()`, no new operations can be added to
+ `g`. This method is used to ensure that no operations are added
+ to a graph when it is shared between multiple threads, for example
+ when using a [`QueueRunner`](train.md#QueueRunner).
+ """
+ self._finalized = True
+
+ def _get_control_flow_context(self):
+ """Returns the current control flow context.
+
+ Returns:
+ A context object.
+ """
+ return self._control_flow_context
+
+ def _set_control_flow_context(self, context):
+ """Sets the current control flow context.
+
+ Args:
+ context: a context object.
+ """
+ self._control_flow_context = context
+
+ def as_graph_def(self, from_version=None):
+ """Returns a serialized `GraphDef` representation of this graph.
+
+ This method is thread-safe.
+
+ Args:
+ from_version: Optional. If this is set, returns a `GraphDef`
+ containing only the nodes that were added to this graph since
+ its `version` property had the given value.
+
+ Returns:
+ A
+ [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+ protocol buffer.
+ """
+ graph = graph_pb2.GraphDef()
+ bytesize = 0
+ for op_id in sorted(self._nodes_by_id):
+ op = self._nodes_by_id[op_id]
+ if from_version is None or op_id > from_version:
+ graph.node.extend([op.node_def])
+ bytesize += op.node_def.ByteSize()
+ if bytesize >= (1 << 31) or bytesize < 0:
+ raise ValueError("GraphDef cannot be larger than 2GB.")
+ return graph
+
  # Helper functions to create operations.
  def create_op(self, op_type, inputs, dtypes,
                input_types=None, name=None, attrs=None, op_def=None,
                compute_shapes=True):
    """Creates an `Operation` in this graph.

    This is a low-level interface for creating an `Operation`. Most
    programs will not call this method directly, and instead use the
    Python op constructors, such as `tf.constant()`, which add ops to
    the default graph.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: A list of `DType` objects that will be the types of the tensors
        that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`. A name ending in "/" is
        treated as a "name scope" and used as-is (minus the trailing "/"),
        without uniquification.
      attrs: (Optional.) A list of `AttrValue` protos for the `attr` field of
        the `NodeDef` proto that will represent the operation.
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_shapes: (Optional.) If True, shape inference will be performed
        to compute the shapes of the outputs.

    Raises:
      TypeError: if any of the inputs is not a `Tensor`.

    Returns:
      An `Operation` object.

    """
    self._check_not_finalized()
    for idx, a in enumerate(inputs):
      if not isinstance(a, Tensor):
        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
    if name is None:
      name = op_type
    # If a name ends with a '/' it is a "name scope" and we use it as-is,
    # after removing the trailing '/'.
    if name and name[-1] == "/":
      name = name[:-1]
    else:
      name = self.unique_name(name)

    node_def = _NodeDef(
        op_type, name, device=self._default_device or None, attrs=attrs)

    # Apply a kernel label if one has been specified for this op_type.
    # (EAFP: most op types have no label, so the lookup usually fails.)
    try:
      kernel_label = self._op_to_kernel_label_map[op_type]
      node_def.attr["_kernel"].CopyFrom(
          attr_value_pb2.AttrValue(s=kernel_label))
    except KeyError:
      pass

    # Apply the overriding op_type for gradients if one has been
    # specified for this op_type.
    try:
      mapped_op_type = self._gradient_override_map[op_type]
      node_def.attr["_gradient_op_type"].CopyFrom(
          attr_value_pb2.AttrValue(s=mapped_op_type))
    except KeyError:
      pass

    control_inputs = self._control_dependencies_for_inputs(inputs)
    ret = Operation(node_def, self, inputs=inputs, output_types=dtypes,
                    control_inputs=control_inputs, input_types=input_types,
                    original_op=self._default_original_op, op_def=op_def)
    if compute_shapes:
      set_shapes_for_outputs(ret)
    self._add_op(ret)
    self._record_op_seen_by_control_dependencies(ret)
    # Apply any device functions in reverse order, so that the most recently
    # pushed function has the first chance to apply a device to the op.
    # We apply here because the result can depend on the Operation's
    # signature, which is computed in the Operation constructor.
    for device_function in reversed(self._device_function_stack):
      ret._set_device(device_function(ret))
    return ret
+
+ def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
+ """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
+
+ This function validates that `obj` represents an element of this
+ graph, and gives an informative error message if it is not.
+
+ This function is the canonical way to get/validate an object of
+ one of the allowed types from an external argument reference in the
+ Session API.
+
+ This method may be called concurrently from multiple threads.
+
+ Args:
+ obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
+ Can also be any object with an `_as_graph_element()` method that returns
+ a value of one of these types.
+ allow_tensor: If true, `obj` may refer to a `Tensor`.
+ allow_operation: If true, `obj` may refer to an `Operation`.
+
+ Returns:
+ The `Tensor` or `Operation` in the Graph corresponding to `obj`.
+
+ Raises:
+ TypeError: If `obj` is not a type we support attempting to convert
+ to types.
+ ValueError: If `obj` is of an appropriate type but invalid. For
+ example, an invalid string.
+ KeyError: If `obj` is not an object in the graph.
+ """
+
+ # The vast majority of this function is figuring
+ # out what an API user might be doing wrong, so
+ # that we can give helpful error messages.
+ #
+ # Ideally, it would be nice to split it up, but we
+ # need context to generate nice error messages.
+
+ if allow_tensor and allow_operation:
+ types_str = "Tensor or Operation"
+ elif allow_tensor:
+ types_str = "Tensor"
+ elif allow_operation:
+ types_str = "Operation"
+ else:
+ raise ValueError("allow_tensor and allow_operation can't both be False.")
+
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ obj = conv_fn()
+
+ # If obj appears to be a name...
+ if isinstance(obj, basestring):
+ name = obj
+
+ if ":" in name and allow_tensor:
+ # Looks like a Tensor name and can be a Tensor.
+ try:
+ op_name, out_n = name.split(":")
+ out_n = int(out_n)
+ except:
+ raise ValueError("The name %s looks a like a Tensor name, but is "
+ "not a valid one. Tensor names must be of the "
+ "form \"<op_name>:<output_index>\"." % repr(name))
+ if op_name in self._nodes_by_name:
+ op = self._nodes_by_name[op_name]
+ else:
+ raise KeyError("The name %s refers to a Tensor which does not "
+ "exist. The operation, %s, does not exist in the "
+ "graph." % (repr(name), repr(op_name)))
+ try:
+ return op.outputs[out_n]
+ except:
+ raise KeyError("The name %s refers to a Tensor which does not "
+ "exist. The operation, %s, exists but only has "
+ "%s outputs."
+ % (repr(name), repr(op_name), len(op.outputs)))
+
+ elif ":" in name and not allow_tensor:
+ # Looks like a Tensor name but can't be a Tensor.
+ raise ValueError("Name %s appears to refer to a Tensor, not a %s."
+ % (repr(name), types_str))
+
+ elif ":" not in name and allow_operation:
+ # Looks like an Operation name and can be an Operation.
+ if name not in self._nodes_by_name:
+ raise KeyError("The name %s refers to an Operation not in the "
+ "graph." % repr(name))
+ return self._nodes_by_name[name]
+
+ elif ":" not in name and not allow_operation:
+ # Looks like an Operation name but can't be an Operation.
+ if name in self._nodes_by_name:
+ # Yep, it's an Operation name
+ err_msg = ("The name %s refers to an Operation, not a %s."
+ % (repr(name), types_str))
+ else:
+ err_msg = ("The name %s looks like an (invalid) Operation name, "
+ "not a %s." % (repr(name), types_str))
+ err_msg += (" Tensor names must be of the form "
+ "\"<op_name>:<output_index>\".")
+ raise ValueError(err_msg)
+
+ elif isinstance(obj, Tensor) and allow_tensor:
+ # Actually obj is just the object it's referring to.
+ return obj
+ elif isinstance(obj, Operation) and allow_operation:
+ # Actually obj is just the object it's referring to.
+ return obj
+ else:
+ # We give up!
+ raise TypeError("Can not convert a %s into a %s."
+ % (type(obj).__name__, types_str))
+
+ def get_operations(self):
+ """Return the list of operations in the graph.
+
+ You can modify the operations in place, but modifications
+ to the list such as inserts/delete have no effect on the
+ list of operations known to the graph.
+
+ This method may be called concurrently from multiple threads.
+
+ Returns:
+ A list of Operations.
+ """
+ return self._nodes_by_id.values()
+
+ def get_operation_by_name(self, name):
+ """Returns the `Operation` with the given `name`.
+
+ This method may be called concurrently from multiple threads.
+
+ Args:
+ name: The name of the `Operation` to return.
+
+ Returns:
+ The `Operation` with the given `name`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ KeyError: If `name` does not correspond to an operation in this graph.
+ """
+
+ if not isinstance(name, basestring):
+ raise TypeError("Operation names are strings (or similar), not %s."
+ % type(name).__name__)
+ return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
+
+ def get_tensor_by_name(self, name):
+ """Returns the `Tensor` with the given `name`.
+
+ This method may be called concurrently from multiple threads.
+
+ Args:
+ name: The name of the `Tensor` to return.
+
+ Returns:
+ The `Tensor` with the given `name`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ KeyError: If `name` does not correspond to a tensor in this graph.
+ """
+ # Names should be strings.
+ if not isinstance(name, basestring):
+ raise TypeError("Tensor names are strings (or similar), not %s."
+ % type(name).__name__)
+ return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
+
+ def _next_id(self):
+ """Id for next Operation instance. Also increments the internal id."""
+ self._check_not_finalized()
+ self._next_id_counter += 1
+ return self._next_id_counter
+
  @property
  def _last_id(self):
    """The id most recently returned by `_next_id()` (0 if none yet)."""
    return self._next_id_counter
+
  def as_default(self):
    """Returns a context manager that makes this `Graph` the default graph.

    This method should be used if you want to create multiple graphs
    in the same process. For convenience, a global default graph is
    provided, and all ops will be added to this graph if you do not
    create a new graph explicitly. Use this method the `with` keyword
    to specify that ops created within the scope of a block should be
    added to this graph.

    The default graph is a property of the current thread. If you
    create a new thread, and wish to use the default graph in that
    thread, you must explicitly add a `with g.as_default():` in that
    thread's function.

    The following code examples are equivalent:

    ```python
    # 1. Using Graph.as_default():
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(5.0)
      assert c.graph is g

    # 2. Constructing and making default:
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0)
      assert c.graph is g
    ```

    Returns:
      A context manager for using this graph as the default graph.
    """
    # Delegates to the module-level default-graph stack; presumably the
    # stack is thread-local (see the thread note above) — confirm in the
    # _default_graph_stack definition.
    return _default_graph_stack.get_controller(self)
+
+ def add_to_collection(self, name, value):
+ """Stores `value` in the collection with the given `name`.
+
+ Args:
+ name: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ value: The value to add to the collection.
+ """
+ self._check_not_finalized()
+ if name not in self._collections:
+ self._collections[name] = [value]
+ else:
+ self._collections[name].append(value)
+
+ def get_collection(self, name, scope=None):
+ """Returns a list of values in the collection with the given `name`.
+
+ Args:
+ key: The key for the collection. For example, the `GraphKeys` class
+ contains many standard names for collections.
+ scope: (Optional.) If supplied, the resulting list is filtered to include
+ only items whose name begins with this string.
+
+ Returns:
+ The list of values in the collection with the given `name`, or
+ an empty list if no value has been added to that collection. The
+ list contains the values in the order under which they were
+ collected.
+ """
+ if scope is None:
+ return self._collections.get(name, list())
+ else:
+ c = []
+ for item in self._collections.get(name, list()):
+ if hasattr(item, 'name') and item.name.startswith(scope):
+ c.append(item)
+ return c
+
+ @contextlib.contextmanager
+ def _original_op(self, op):
+ """Python 'with' handler to help annotate ops with their originator.
+
+ An op may have an 'original_op' property that indicates the op on which
+ it was based. For example a replica op is based on the op that was
+ replicated and a gradient op is based on the op that was differentiated.
+
+ All ops created in the scope of this 'with' handler will have
+ the given 'op' as their original op.
+
+ Args:
+ op: The Operation that all ops created in this scope will have as their
+ original op.
+
+ Yields:
+ Nothing.
+ """
+ old_original_op = self._default_original_op
+ try:
+ self._default_original_op = op
+ yield
+ finally:
+ self._default_original_op = old_original_op
+
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def name_scope(self, name):
    """Returns a context manager that creates hierarchical names for operations.

    A graph maintains a stack of name scopes. A `with name_scope(...):`
    statement pushes a new name onto the stack for the lifetime of the context.

    The `name` argument will be interpreted as follows:

    * A string (not ending with '/') will create a new name scope, in which
      `name` is appended to the prefix of all operations created in the
      context. If `name` has been used before, it will be made unique by
      calling `self.unique_name(name)`.
    * A scope previously captured from a `with g.name_scope(...) as
      scope:` statement will be treated as an "absolute" name scope, which
      makes it possible to re-enter existing scopes.
    * A value of `None` or the empty string will reset the current name scope
      to the top-level (empty) name scope.

    For example:

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, name="c")
      assert c.name == "c"
      c_1 = tf.constant(6.0, name="c")
      assert c_1.name == "c_1"

      # Creates a scope called "nested"
      with g.name_scope("nested") as scope:
        nested_c = tf.constant(10.0, name="c")
        assert nested_c.name == "nested/c"

        # Creates a nested scope called "inner".
        with g.name_scope("inner"):
          nested_inner_c = tf.constant(20.0, name="c")
          assert nested_inner_c.name == "nested/inner/c"

        # Create a nested scope called "inner_1".
        with g.name_scope("inner"):
          nested_inner_1_c = tf.constant(30.0, name="c")
          assert nested_inner_1_c.name == "nested/inner_1/c"

      # Treats `scope` as an absolute name scope, and
      # switches to the "nested/" scope.
      with g.name_scope(scope):
        nested_d = tf.constant(40.0, name="d")
        assert nested_d.name == "nested/d"

        with g.name_scope(""):
          e = tf.constant(50.0, name="e")
          assert e.name == "e"
    ```

    The name of the scope itself can be captured by `with
    g.name_scope(...) as scope:`, which stores the name of the scope
    in the variable `scope`. This value can be used to name an
    operation that represents the overall result of executing the ops
    in a scope. For example:

    ```python
    inputs = tf.constant(...)
    with g.name_scope('my_layer') as scope:
      weights = tf.Variable(..., name="weights")
      biases = tf.Variable(..., name="biases")
      affine = tf.matmul(inputs, weights) + biases
      output = tf.nn.relu(affine, name=scope)
    ```


    Args:
      name: A name for the scope.

    Returns:
      A context manager that installs `name` as a new name scope.
    """
    try:
      old_stack = self._name_stack
      if not name:  # Both for name=None and name="" we re-set to empty scope.
        new_stack = (None, None)
      elif name and name[-1] == "/":
        # Absolute scope re-entry: use the captured scope string as-is.
        new_stack = (name[:-1], name[:-1])
      else:
        # New scope: uniquified name for op naming, plain name for display.
        new_stack = (self.unique_name(name), self._plain_name(name))
      self._name_stack = new_stack
      yield "" if new_stack[0] is None else new_stack[0] + "/"
    finally:
      self._name_stack = old_stack
  # pylint: enable=g-doc-return-or-yield
+
+ def unique_name(self, name):
+ """Return a unique Operation name for "name".
+
+ Note: You rarely need to call unique_name() directly. Most of the time you
+ just need to create "with g.name_scope()" blocks to generate structured
+ names.
+
+ `unique_name` is used to generate structured names, separated by "/",
+ to help identify Operations when debugging a Graph. Operation names
+ are displayed in error messages reported by the TensorFlow runtime,
+ and in various visualization tools such as TensorBoard.
+
+ Args:
+ name: The name for an `Operation`.
+
+ Returns:
+ A string to be passed to `create_op()` that will be used
+ to name the operation being created.
+ """
+ if self._name_stack[0]:
+ name = self._name_stack[0] + "/" + name
+ i = self._names_in_use.get(name, 0)
+ # Increment the number for "name".
+ self._names_in_use[name] = i + 1
+ if i > 0:
+ base_name = name
+ # Make sure the composed name is not already used.
+ while name in self._names_in_use:
+ name = "%s_%d" % (base_name, i)
+ i += 1
+ # Mark the composed name as used in case someone wants
+ # to call unique_name("name_1").
+ self._names_in_use[name] = 1
+ return name
+
+ # TODO(mdevin): remove
+ def _plain_name(self, name):
+ """Return the fully scoped 'name'.
+
+ Args:
+ name: a string.
+
+ Returns:
+ 'name' scoped in the current name stack, without any uniquified
+ elements.
+ """
+ if self._name_stack[1]:
+ return self._name_stack[1] + "/" + name
+ else:
+ return name
+
  def _set_default_device(self, dev):
    """Set the default device properties.

    Args:
      dev: string or Device. Normalized to a string via `_device_string()`
        before being stored.
    """
    self._default_device = _device_string(dev)
+
+ def get_default_device(self):
+ """Returns the default device.
+
+ Returns:
+ A string.
+ """
+ return self._default_device
+
+ def _push_default_device_function(self, device_function):
+ """Pushes the given function onto the stack of device functions.
+
+ See Graph.device for more details.
+
+ Args:
+ device_function: The function to be pushed onto the stack of device
+ functions.
+ """
+ self._device_function_stack.append(device_function)
+
+ def _pop_default_device_function(self, device_function):
+ """Pops the given function from the stack of device functions.
+
+ See Graph.device for more details.
+
+ Args:
+ device_function: The function to be popped from the stack of device
+ functions.
+
+ Raises:
+ ValueError: if the device_function to be popped is not top of the stack,
+ or if the stack is empty.
+ """
+ if not self._device_function_stack:
+ raise ValueError("Tried to pop, but the device function stack is empty")
+ if self._device_function_stack[-1] is not device_function:
+ raise ValueError("Tried to pop device function, but it was not on top "
+ "of the stack")
+
+ self._device_function_stack.pop()
+
+ @contextlib.contextmanager
+ def device(self, device_name_or_function):
+ """Returns a context manager that specifies the default device to use.
+
+ The `device_name_or_function` argument may either be a device name
+ string, a device function, or None:
+
+ * If it is a device name string, all operations constructed in
+ this context will be assigned to the device with that name.
+ * If it is a function, it will be treated as function from
+ Operation objects to device name strings, and invoked each time
+ a new Operation is created. The Operation will be assigned to
+ the device with the returned name.
+ * If it is None, the default device will be cleared.
+
+ For example:
+
+ ```python
+ with g.device('/gpu:0'):
+ # All operations constructed in this context will be placed
+ # on GPU 0.
+ with g.device(None):
+ # All operations constructed in this context will have no
+ # assigned device.
+
+ # Defines a function from `Operation` to device string.
+ def matmul_on_gpu(n):
+ if n.type == "MatMul":
+ return "/gpu:0"
+ else:
+ return "/cpu:0"
+
+ with g.device(matmul_on_gpu):
+ # All operations of type "MatMul" constructed in this context
+ # will be placed on GPU 0; all other operations will be placed
+ # on CPU 0.
+ ```
+
+ Args:
+ device_name_or_function: The device name or function to use in
+ the context.
+
+ Returns:
+ A context manager that specifies the default device to use for newly
+ created ops.
+ """
+ if callable(device_name_or_function):
+ try:
+ self._push_default_device_function(device_name_or_function)
+ yield
+ finally:
+ self._pop_default_device_function(device_name_or_function)
+ else:
+ try:
+ old_dev = self.get_default_device()
+ self._set_default_device(_device_string(device_name_or_function))
+ yield
+ finally:
+ self._set_default_device(old_dev)
+
+ class _ControlDependenciesController(object):
+ """Context manager for `control_dependencies()`."""
+
+ def __init__(self, graph, control_inputs):
+ self._graph = graph
+ self._control_inputs = control_inputs
+ self._seen_nodes = set()
+
+# pylint: disable=protected-access
+ def __enter__(self):
+ self._graph._push_control_dependencies_controller(self)
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ self._graph._pop_control_dependencies_controller(self)
+# pylint: enable=protected-access
+
+ @property
+ def control_inputs(self):
+ return self._control_inputs
+
+ def add_op(self, op):
+ self._seen_nodes.add(op)
+
+ def op_in_group(self, op):
+ return op in self._seen_nodes
+
  def _push_control_dependencies_controller(self, controller):
    """Pushes `controller` onto the stack of active control-dependency scopes."""
    self._control_dependencies_stack.append(controller)

  def _pop_control_dependencies_controller(self, controller):
    """Pops `controller`; it must be on top of the stack (LIFO invariant)."""
    assert self._control_dependencies_stack[-1] is controller
    self._control_dependencies_stack.pop()
+
+ def _current_control_dependencies(self):
+ ret = set()
+ for controller in self._control_dependencies_stack:
+ for op in controller.control_inputs:
+ ret.add(op)
+ return ret
+
+ def _control_dependencies_for_inputs(self, input_tensors):
+ """For an op that takes `input_tensors` as inputs, compute control inputs.
+
+ The returned control dependencies should yield an execution that
+ is equivalent to adding all control inputs in
+ self._control_dependencies_stack to a newly created op. However,
+ this function attempts to prune the returned control dependencies
+ by observing that nodes created within the same `with
+ control_dependencies(...):` block may have data dependencies that make
+ the explicit approach redundant.
+
+ Args:
+ input_tensors: The direct data dependencies for an op to be created.
+
+ Returns:
+ A list of control inputs for the op to be created.
+ """
+ ret = []
+ input_ops = set([t.op for t in input_tensors])
+ for controller in self._control_dependencies_stack:
+ # If any of the input_ops already depends on the inputs from controller,
+ # we say that the new op is dominated (by that input), and we therefore
+ # do not need to add control dependences for this controller's inputs.
+ dominated = False
+ for op in input_ops:
+ if controller.op_in_group(op):
+ dominated = True
+ break
+ if not dominated:
+ # Don't add a control input if we already have a data dependency on i.
+ # NOTE(mrry): We do not currently track transitive data dependencies,
+ # so we may add redundant control inputs.
+ ret.extend([c for c in controller.control_inputs if c not in input_ops])
+ return ret
+
+ def _record_op_seen_by_control_dependencies(self, op):
+ """Record that the given op depends on all registered control dependencies.
+
+ Args:
+ op: An Operation.
+ """
+ for controller in self._control_dependencies_stack:
+ controller.add_op(op)
+
  def control_dependencies(self, control_inputs):
    """Returns a context manager that specifies control dependencies.

    Use with the `with` keyword to specify that all operations constructed
    within the context should have control dependencies on
    `control_inputs`. For example:

    ```python
    with g.control_dependencies([a, b, c]):
      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
      d = ...
      e = ...
    ```

    Multiple calls to `control_dependencies()` can be nested, and in
    that case a new `Operation` will have control dependencies on the union
    of `control_inputs` from all active contexts.

    ```python
    with g.control_dependencies([a, b]):
      # Ops declared here run after `a` and `b`.
      with g.control_dependencies([c, d]):
        # Ops declared here run after `a`, `b`, `c`, and `d`.
    ```

    *N.B.* The control dependencies context applies *only* to ops that
    are constructed within the context. Merely using an op or tensor
    in the context does not add a control dependency. The following
    example illustrates this point:

    ```python
    # WRONG
    def my_func(pred, tensor):
      t = tf.matmul(tensor, tensor)
      with tf.control_dependencies([pred]):
        # The matmul op is created outside the context, so no control
        # dependency will be added.
        return t

    # RIGHT
    def my_func(pred, tensor):
      with tf.control_dependencies([pred]):
        # The matmul op is created in the context, so a control dependency
        # will be added.
        return tf.matmul(tensor, tensor)
    ```

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects, which
        must be executed or computed before running the operations
        defined in the context.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    # First convert the inputs to ops, and deduplicate them.
    # NOTE(mrry): Other than deduplication, we do not currently track direct
    # or indirect dependencies between control_inputs, which may result in
    # redundant control inputs.
    control_ops = []
    # `current` seeds the dedup set with dependencies already installed by
    # enclosing control_dependencies() contexts.
    current = self._current_control_dependencies()
    for c in control_inputs:
      # A Tensor is replaced by its producing Operation.
      if isinstance(c, Tensor):
        c = c.op
      elif not isinstance(c, Operation):
        raise TypeError("Control input must be Operation or Tensor: %s" % c)
      if c not in current:
        control_ops.append(c)
        current.add(c)
    return self._ControlDependenciesController(self, control_ops)
+
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def _kernel_label_map(self, op_to_kernel_label_map):
    """EXPERIMENTAL: A context manager for setting kernel labels.

    This context manager can be used to select particular
    implementations of kernels within the scope of the context.

    For example:

        with ops.Graph().as_default() as g:
          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
          with g._kernel_label_map({"Foo": "v_2"}):
            f_2 = Foo()  # Uses the registered kernel with label "v_2"
                         # for the Foo op.
            with g._kernel_label_map({"Foo": "v_3"}):
              f_3 = Foo()  # Uses the registered kernel with label "v_3"
                           # for the Foo op.
              with g._kernel_label_map({"Foo": ""}):
                f_4 = Foo()  # Uses the default registered kernel
                             # for the Foo op.

    Args:
      op_to_kernel_label_map: A dictionary mapping op type strings to
        kernel label strings.

    Returns:
      A context manager that sets the kernel label to be used for one or more
      ops created in that context.

    Raises:
      TypeError: If op_to_kernel_label_map is not a dictionary mapping
        strings to strings.
    """
    if not isinstance(op_to_kernel_label_map, dict):
      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                      "strings to strings")
    # The saved_labels dictionary stores any currently-set labels that
    # will be overridden by this context manager.
    saved_labels = {}
    # Install the given label
    for op_type, label in op_to_kernel_label_map.items():
      if not (isinstance(op_type, basestring)
              and isinstance(label, basestring)):
        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
                        "strings to strings")
      try:
        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
      except KeyError:
        pass
      self._op_to_kernel_label_map[op_type] = label
    try:
      yield  # The code within the context runs here.
    finally:
      # Remove the labels set for this context, and restore any saved labels.
      for op_type, label in op_to_kernel_label_map.items():
        try:
          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
        except KeyError:
          del self._op_to_kernel_label_map[op_type]
  # pylint: enable=g-doc-return-or-yield
+
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def gradient_override_map(self, op_type_map):
    """EXPERIMENTAL: A context manager for overriding gradient functions.

    This context manager can be used to override the gradient function
    that will be used for ops within the scope of the context.

    For example:

    ```python
    @tf.RegisterGradient("CustomSquare")
    def _custom_square_grad(op, inputs):
      # ...

    with tf.Graph().as_default() as g:
      c = tf.constant(5.0)
      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
      with g.gradient_override_map({"Square": "CustomSquare"}):
        s_2 = tf.square(s_1)  # Uses _custom_square_grad to compute the
                              # gradient of s_2.
    ```

    Args:
      op_type_map: A dictionary mapping op type strings to alternative op
        type strings.

    Returns:
      A context manager that sets the alternative op type to be used for one
      or more ops created in that context.

    Raises:
      TypeError: If `op_type_map` is not a dictionary mapping strings to
        strings.
    """
    if not isinstance(op_type_map, dict):
      raise TypeError("op_type_map must be a dictionary mapping "
                      "strings to strings")
    # The saved_mappings dictionary stores any currently-set mappings that
    # will be overridden by this context manager.
    saved_mappings = {}
    # Install the given label
    for op_type, mapped_op_type in op_type_map.items():
      if not (isinstance(op_type, basestring)
              and isinstance(mapped_op_type, basestring)):
        raise TypeError("op_type_map must be a dictionary mapping "
                        "strings to strings")
      try:
        saved_mappings[op_type] = self._gradient_override_map[op_type]
      except KeyError:
        pass
      self._gradient_override_map[op_type] = mapped_op_type
    try:
      yield  # The code within the context runs here.
    finally:
      # Remove the labels set for this context, and restore any saved labels.
      for op_type, mapped_op_type in op_type_map.items():
        try:
          self._gradient_override_map[op_type] = saved_mappings[op_type]
        except KeyError:
          del self._gradient_override_map[op_type]
  # pylint: enable=g-doc-return-or-yield
+
+
def device(dev):
  """Wrapper for `Graph.device()` using the default graph.

  See [`Graph.device()`](framework.md#Graph.device) for more details.

  Args:
    dev: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.
  """
  return get_default_graph().device(dev)
+
+
def name_scope(name):
  """Wrapper for `Graph.name_scope()` using the default graph.

  See [`Graph.name_scope()`](framework.md#Graph.name_scope) for more details.

  Args:
    name: A name for the scope.

  Returns:
    A context manager that installs `name` as a new name scope in the
    default graph.
  """
  default_graph = get_default_graph()
  return default_graph.name_scope(name)
+
+
def control_dependencies(control_inputs):
  """Wrapper for `Graph.control_dependencies()` using the default graph.

  See [`Graph.control_dependencies()`](framework.md#Graph.control_dependencies)
  for more details.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects that must be
      executed or computed before the operations defined in the context run.

  Returns:
    A context manager that specifies control dependencies for all
    operations constructed within the context.
  """
  default_graph = get_default_graph()
  return default_graph.control_dependencies(control_inputs)
+
+
+class _DefaultStack(threading.local):
+ """A thread-local stack of objects for providing implicit defaults."""
+
+ def __init__(self):
+ super(_DefaultStack, self).__init__()
+ self.stack = []
+
+ def get_default(self):
+ return self.stack[-1] if len(self.stack) >= 1 else None
+
+ def reset(self):
+ self.stack = []
+
+ @contextlib.contextmanager
+ def get_controller(self, default):
+ """A context manager for manipulating a default stack."""
+ try:
+ self.stack.append(default)
+ yield default
+ finally:
+ assert self.stack[-1] is default
+ self.stack.pop()
+
+
# Thread-local stack of weak references to sessions, pushed by
# `default_session()` and read back by `get_default_session()`.
_default_session_stack = _DefaultStack()
+
+
def default_session(session):
  """Python "with" handler for defining a default session.

  Registers `session` as the session that services `Tensor.eval()` and
  `Operation.run()` calls inside the "with" block. It is primarily intended
  for use by `session.Session`, but works with any object implementing the
  `Session.run()` interface.

  Note that only a weak reference to `session` is stored, so installing it
  as the default does not by itself keep the session alive.

  The default session applies to the current thread only, so it is always
  possible to inspect the call stack and determine the scope of a default
  session. A new thread that wishes to use the default session must enter
  its own "with ops.default_session(sess):" block.

  Example:
    The following code examples are equivalent:

    # 1. Using the Session object directly:
    sess = ...
    c = tf.constant(5.0)
    sess.run(c)

    # 2. Using default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      result = c.eval()

    # 3. Overriding default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      with ops.default_session(...):
        c.eval(session=sess)

  Args:
    session: The session to be installed as the default session.

  Returns:
    A context manager for the default session.
  """
  weak_session = weakref.ref(session)
  return _default_session_stack.get_controller(weak_session)
+
+
def get_default_session():
  """Returns the default session for the current thread.

  The returned `Session` is the innermost session on which a `Session` or
  `Session.as_default()` context has been entered.

  *N.B.* The default session is a property of the current thread. A new
  thread that wishes to use the default session must explicitly enter a
  `with sess.as_default():` block in its own function.

  Returns:
    The default `Session` being used in the current thread, or `None` if
    no session has been registered.

  Raises:
    RuntimeError: If the registered session has been garbage collected.
  """
  ref = _default_session_stack.get_default()
  if ref is None:
    # No default session has been registered.
    return None
  # The stack holds weak references; dereference to get the session.
  session = ref()
  if session is None:
    # This should never happen with the current session implementations.
    raise RuntimeError("Default session has been garbage collected.")
  return session
+
+
+def _eval_using_default_session(tensors, feed_dict, graph, session=None):
+ """Uses the default session to evaluate one or more tensors.
+
+ Args:
+ tensors: A single Tensor, or a list of Tensor objects.
+ feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
+ numpy ndarrays, TensorProtos, or strings.
+ graph: The graph in which the tensors are defined.
+ session: (Optional) A different session to use to evaluate "tensors".
+
+ Returns:
+ Either a single numpy ndarray if "tensors" is a single tensor; or a list
+ of numpy ndarrays that each correspond to the respective element in
+ "tensors".
+
+ Raises:
+ ValueError: If no default session is available; the default session
+ does not have "graph" as its graph; or if "session" is specified,
+ and it does not have "graph" as its graph.
+ """
+ if session is None:
+ session = get_default_session()
+ if session is None:
+ raise ValueError("Cannot evaluate tensor using eval(): No default "
+ "session is registered. Use 'with "
+ "DefaultSession(sess)' or pass an explicit session to "
+ "eval(session=sess)")
+ if session.graph is not graph:
+ raise ValueError("Cannot use the default session to evaluate tensor: "
+ "the tensor's graph is different from the session's "
+ "graph. Pass an explicit session to "
+ "eval(session=sess).")
+ else:
+ if session.graph is not graph:
+ raise ValueError("Cannot use the given session to evaluate tensor: "
+ "the tensor's graph is different from the session's "
+ "graph.")
+ return session.run(tensors, feed_dict)
+
+
+def _run_using_default_session(operation, feed_dict, graph, session=None):
+ """Uses the default session to run "operation".
+
+ Args:
+ operation: The Operation to be run.
+ feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
+ numpy ndarrays, TensorProtos, or strings.
+ graph: The graph in which "operation" is defined.
+ session: (Optional) A different session to use to run "operation".
+
+ Raises:
+ ValueError: If no default session is available; the default session
+ does not have "graph" as its graph; or if "session" is specified,
+ and it does not have "graph" as its graph.
+ """
+ if session is None:
+ session = get_default_session()
+ if session is None:
+ raise ValueError("Cannot execute operation using Run(): No default "
+ "session is registered. Use 'with "
+ "default_session(sess)' or pass an explicit session to "
+ "Run(session=sess)")
+ if session.graph is not graph:
+ raise ValueError("Cannot use the default session to execute operation: "
+ "the operation's graph is different from the "
+ "session's graph. Pass an explicit session to "
+ "Run(session=sess).")
+ else:
+ if session.graph is not graph:
+ raise ValueError("Cannot use the given session to execute operation: "
+ "the operation's graph is different from the session's "
+ "graph.")
+ session.run(operation, feed_dict)
+
+
class _DefaultGraphStack(_DefaultStack):
  """A thread-local stack of objects for providing an implicit default graph."""

  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    # Lazily-created graph used when no graph has been pushed explicitly.
    self._global_default_graph = None

  def get_default(self):
    """Returns the innermost pushed graph, or a lazily-created global one."""
    explicit = super(_DefaultGraphStack, self).get_default()
    if explicit is not None:
      return explicit
    return self._GetGlobalDefaultGraph()

  def _GetGlobalDefaultGraph(self):
    # Create the process-global graph on first use.
    # TODO(mrry): Perhaps log that the default graph is being used, or
    # provide some other feedback to prevent confusion when a mixture of
    # the global default graph and an explicit graph are combined in the
    # same process.
    if self._global_default_graph is None:
      self._global_default_graph = Graph()
    return self._global_default_graph

  def reset(self):
    super(_DefaultGraphStack, self).reset()
    self._global_default_graph = None
+
# Thread-local stack of default graphs; `get_default_graph()` reads it and
# `_DefaultGraphStack` lazily supplies a global graph when it is empty.
_default_graph_stack = _DefaultGraphStack()
+
+
def reset_default_graph():
  """Clears the default graph stack and resets the global default graph.

  Note that the default graph is a property of the current thread, so this
  call only affects the thread from which it is made.
  """
  _default_graph_stack.reset()
+
+
def get_default_graph():
  """Returns the default graph for the current thread.

  This is the innermost graph on which a `Graph.as_default()` context has
  been entered, or a global default graph when none has been explicitly
  created.

  *N.B.* The default graph is a property of the current thread. A new
  thread that wishes to share a graph must explicitly enter a
  `with g.as_default():` block in its own function.

  Returns:
    The default `Graph` being used in the current thread.
  """
  return _default_graph_stack.get_default()
+
+
+def _get_graph_from_inputs(op_input_list, graph=None):
+ """Returns the appropriate graph to use for the given inputs.
+
+ This library method provides a consistent algorithm for choosing the graph
+ in which an Operation should be constructed:
+
+ 1. If the "graph" is specified explicitly, we validate that all of the inputs
+ in "op_input_list" are compatible with that graph.
+ 2. Otherwise, we attempt to select a graph from the first Operation-
+ or Tensor-valued input in "op_input_list", and validate that all other
+ such inputs are in the same graph.
+ 3. If the graph was not specified and it could not be inferred from
+ "op_input_list", we attempt to use the default graph.
+
+ Args:
+ op_input_list: A list of inputs to an operation, which may include Tensor
+ and Operation objects.
+ graph: (Optional) The explicit graph to use.
+
+ Raises:
+ TypeError: If op_input_list is not a list or tuple, or if graph is not a
+ Graph.
+ ValueError: If a graph is explicitly passed and not all inputs are from it,
+ or if the inputs are from multiple graphs, or we could not find a graph
+ and there was no default graph.
+
+ Returns:
+ The appropriate graph to use for the given inputs.
+ """
+ if not isinstance(op_input_list, (list, tuple)):
+ raise TypeError("The op_input_list must be a list or tuple")
+
+ # 1. If the graph is specified explicitly, we validate that all of the inputs
+ # are compatible with that graph.
+ if graph is not None:
+ if not isinstance(graph, Graph):
+ raise TypeError("Input graph needs to be a Graph: %s" % graph)
+ for op_input in op_input_list:
+ if isinstance(op_input, Operation):
+ if op_input.graph is not graph:
+ raise ValueError("Operation %s is not from the passed-in graph"
+ % op_input)
+ elif isinstance(op_input, Tensor):
+ if op_input.graph is not graph:
+ raise ValueError("Tensor %s is not from the passed-in graph"
+ % op_input)
+ return graph
+
+ # 2. Otherwise, we attempt to select a graph from one of the Operation-
+ # or Tensor-valued inputs.
+ original_input = None
+ for op_input in op_input_list:
+ if isinstance(op_input, (Operation, Tensor)):
+ if original_input is None:
+ original_input = op_input
+ else:
+ assert_same_graph([original_input, op_input])
+ if original_input is not None:
+ return original_input.graph
+
+ # 3. If all else fails, we use the default graph, which is always there.
+ return get_default_graph()
+
+
class GraphKeys(object):
  """Standard names to use for graph collections.

  Well-known keys used throughout the standard library to collect and
  retrieve values associated with a graph. For example, the `tf.Optimizer`
  subclasses default to optimizing the variables collected under
  `tf.GraphKeys.TRAINABLE_VARIABLES` when no explicit list of variables is
  given.

  The following standard keys are defined:

  * `VARIABLES`: the `Variable` objects that comprise a model, and
    must be saved and restored together. See
    [`tf.all_variables()`](state_ops.md#all_variables) for more details.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    [`tf.trainable_variables()`](state_ops.md#trainable_variables)
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created
    in the graph. See [`tf.merge_all_summaries()`](train.md#merge_all_summaries)
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    [`tf.start_queue_runners()`](train.md#start_queue_runners) for more details.
  """

  # variables.Variable objects that must be saved/restored with the model.
  VARIABLES = "variables"
  # variables.Variable objects that optimizers train by default.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Summary tensors.
  SUMMARIES = "summaries"
  # QueueRunner objects.
  QUEUE_RUNNERS = "queue_runners"
  # Table initializers.
  TABLE_INITIALIZERS = "table_initializer"
+
+
def add_to_collection(name, value):
  """Wrapper for `Graph.add_to_collection()` using the default graph.

  See [`Graph.add_to_collection()`](framework.md#Graph.add_to_collection)
  for more details.

  Args:
    name: The key for the collection. The `GraphKeys` class contains many
      standard collection names.
    value: The value to add to the collection.
  """
  default_graph = get_default_graph()
  default_graph.add_to_collection(name, value)
+
+
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See [`Graph.get_collection()`](framework.md#Graph.get_collection)
  for more details.

  Args:
    key: The key for the collection. The `GraphKeys` class contains many
      standard collection names.
    scope: (Optional.) If supplied, only items whose name begins with this
      string are included in the result.

  Returns:
    The list of values in the collection with the given `key`, in the order
    in which they were collected, or an empty list if nothing has been
    added to that collection.
  """
  default_graph = get_default_graph()
  return default_graph.get_collection(key, scope)
+
+
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def op_scope(values, name, default_name):
  """Returns a context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a name
  scope.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.op_scope([a, b, c], name, "MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  Args:
    values: The list of `Tensor` arguments that are passed to the op function.
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.

  Returns:
    A context manager for use in defining a Python op.
  """
  graph = _get_graph_from_inputs(values)
  # Only a `None` name falls back to the default; "" is kept as-is.
  if name is None:
    scope_name = default_name
  else:
    scope_name = name
  with graph.as_default(), graph.name_scope(scope_name) as scope:
    yield scope
# pylint: enable=g-doc-return-or-yield
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
new file mode 100644
index 0000000000..a406c5e56e
--- /dev/null
+++ b/tensorflow/python/framework/ops_test.py
@@ -0,0 +1,825 @@
+"""Tests for tensorflow.python.framework.ops."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_kernel_label_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.platform import googletest
+
+
class TensorTest(test_util.TensorFlowTestCase):
  """Tests for the shape accessors of `ops.Tensor`."""

  def testShape(self):
    # A freshly constructed output tensor starts with an unknown shape;
    # set_shape() pins it to a concrete one.
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
                       [], [types.float32])
    t = op.outputs[0]
    self.assertEquals(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEquals([1, 2, 3], t.get_shape())
+
+
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the `ops._NodeDef` convenience constructor."""

  def testNoArgs(self):
    nodedef = ops._NodeDef("noop", "bar")
    self.assertProtoEquals("op: 'noop' name: 'bar'", nodedef)

  def testArgs(self):
    # The `device` keyword ends up in the NodeDef; per the second assertion
    # a pydev.Device object is rendered in its string form.
    nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
    self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
                           nodedef)
    nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
    self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
+
+
# NOTE(mrry): Dummy shape registrations for ops used in the tests.
# Registering `None` marks the op type as known to the shape registry
# without providing a shape function (see OperationTest.testNoShapeFunction,
# which expects unknown output shapes for such ops).
ops.RegisterShape("a")(None)
ops.RegisterShape("b")(None)
ops.RegisterShape("c")(None)
ops.RegisterShape("add")(None)
ops.RegisterShape("an_op")(None)
ops.RegisterShape("const")(None)
ops.RegisterShape("copy")(None)
ops.RegisterShape("foo")(None)
ops.RegisterShape("identity")(None)
ops.RegisterShape("mul")(None)
ops.RegisterShape("nonrefop")(None)
ops.RegisterShape("noop")(None)
ops.RegisterShape("refop")(None)
+
+
+def _apply_op(g, *args, **kwargs):
+ op = g.create_op(*args, **kwargs)
+ if len(op.outputs) == 1:
+ return op.outputs[0]
+ else:
+ return op.outputs
+
+
class OperationTest(test_util.TensorFlowTestCase):
  """Tests for direct construction of `ops.Operation` objects."""

  def testNoInputs(self):
    # Two declared output types yield two output tensors and no inputs.
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
                       [],
                       [types.float32, types.string])
    self.assertEquals(2, len(op.values()))
    self.assertEquals(0, len(op.inputs))
    self.assertEquals("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEquals(types.float32, float_t.dtype)
    self.assertEquals(op, float_t.op)
    self.assertEquals(0, float_t._value_index)
    self.assertEquals(0, len(float_t._consumers))
    self.assertEquals("myop", float_t._as_node_def_input())

    # The second output is referenced in NodeDef input syntax as "myop:1".
    self.assertEquals(types.string, label_str_t.dtype)
    self.assertEquals(op, label_str_t.op)
    self.assertEquals(1, label_str_t._value_index)
    self.assertEquals(0, len(label_str_t._consumers))
    self.assertEquals("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)

  def testNoOutputs(self):
    # Consuming a tensor records the consuming op in the producer's
    # _consumers list and as an input: entry in the consumer's NodeDef.
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("noop", "myop1"), g, [], [types.float32])
    float_t, = op1.values()
    op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
    self.assertEquals(0, len(op2.values()))
    self.assertEquals(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    self.assertEquals(1, len(float_t._consumers))
    self.assertEquals(op2, float_t._consumers[0])

    self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
                           op2.node_def)

  def testInputsAndOutputs(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("noop", "myop1"), g, [], [types.float32])
    self.assertEquals(1, len(op1.values()))
    float1_t, = op1.values()

    op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g,
                        [], [types.float32, types.string])
    self.assertEquals(2, len(op2.values()))
    float2_t, label2_str_t = op2.values()

    # Note that we consume label2_str_t twice here.
    op3 = ops.Operation(ops._NodeDef("add", "myop3"), g,
                        [float1_t, label2_str_t, label2_str_t],
                        [types.float32, types.int32])
    self.assertEquals(2, len(op3.values()))

    self.assertEquals(1, len(float1_t._consumers))
    self.assertEquals(op3, float1_t._consumers[0])

    self.assertEquals(0, len(float2_t._consumers))

    # A doubly-consumed tensor appears twice in the consumers list.
    self.assertEquals(2, len(label2_str_t._consumers))
    self.assertEquals(op3, label2_str_t._consumers[0])
    self.assertEquals(op3, label2_str_t._consumers[1])

    self.assertProtoEquals("""
    op:'add' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, op3.node_def)

  def testDeviceObject(self):
    # _set_device accepts both a device string and a pydev.Device object.
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
    op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
        op.node_def)
    op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
    op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
        op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [],
                        [types.float32_ref, types.float32])
    self.assertProtoEquals("op:'noop' name:'op1'",
                           op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation(
        ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [],
        input_types=[types.float32_ref, types.float32])
    self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
                           op2.node_def)
    op3 = ops.Operation(
        ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], [])
    self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
                           op3.node_def)

  def testInvalidNames(self):
    # Node names may not be empty or begin with '_', '-' or '/'.
    g = ops.Graph()
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", ""), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "_invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "-invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "/invalid"), g)

  def testShapeFunctionAbsence(self):
    # NOTE(review): this inner function is never used — looks like leftover
    # scaffolding; candidate for removal.
    def _test():
      pass
    # "shapeless_op" has no RegisterShape entry at all, so create_op fails.
    g = ops.Graph()
    with self.assertRaises(RuntimeError):
      g.create_op("shapeless_op", [], [types.float32])

  def testNoShapeFunction(self):
    g = ops.Graph()
    # NOTE(review): `op` is unused, and the `output_types = ` spacing is
    # non-PEP8; both preserved here byte-for-byte.
    op = ops.Operation(ops._NodeDef("op", "an_op"), g,
                       output_types = [types.float32])
    # "an_op" is registered with a `None` shape function above, so its
    # outputs get an unknown shape rather than an error.
    self.assertEquals(tensor_shape.unknown_shape(),
                      _apply_op(g, "an_op", [], [types.float32]).get_shape())
+
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.create_op()`."""

  def testNodeDefArgs(self):
    # Ops created inside g.device() pick up the device string; ops created
    # outside keep device == None.
    g = ops.Graph()
    op1 = g.create_op("const", [], [types.float32], None, name="myop1")
    with g.device("/device:GPU"):
      op2 = g.create_op("add",
                        [],
                        [types.float32, types.string], None,
                        name="myop2")
    op3 = g.create_op(
        "foo",
        [op1.values()[0], op2.values()[1], op2.values()[0]],
        [types.float32, types.int32], None,
        name="myop3")
    self.assertEquals(None, op1.device)
    self.assertEquals("/device:GPU", op2.device)
    self.assertEquals(None, op3.device)
    self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def)
    self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
                           op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
        op3.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = g.create_op("noop", [],
                      [types.float32_ref, types.float32], name="op1")
    self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = g.create_op("refop", [ref_t, nonref_t], [],
                      input_types=[types.float32_ref, types.float32],
                      name="op2")
    self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
                           op2.node_def)
    op3 = g.create_op("nonrefop", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
                           op3.node_def)

  def testFinalized(self):
    # A finalized graph rejects any further op creation.
    g = ops.Graph()
    g.finalize()
    with self.assertRaises(RuntimeError):
      g.create_op("const", [], [types.float32], None, name="myop1")
+
+
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for the `_apply_op` helper (single-output vs multi-output)."""

  def testNodeDefArgs(self):
    g = ops.Graph()
    t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
    with g.device("/device:GPU"):
      t2 = _apply_op(g, "add",
                     [],
                     [types.float32, types.string],
                     name="myop2")
    t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]],
                   [types.float32, types.int32], name="myop3")
    # One declared output yields a bare Tensor; several yield a list.
    self.assertTrue(isinstance(t1, ops.Tensor))
    self.assertTrue(isinstance(t2, list))
    self.assertTrue(isinstance(t3, list))
    self.assertTrue(isinstance(t3[0], ops.Tensor))
    self.assertEquals("myop1", t1._as_node_def_input())
    self.assertEquals("myop2", t2[0]._as_node_def_input())
    self.assertEquals("myop2:1", t2[1]._as_node_def_input())
    self.assertEquals("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def)
    self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'",
                           t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "noop", [], [types.float32_ref, types.float32], name="op1")
    self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32],
                      input_types=[types.float32_ref, types.float32],
                      name="op2")
    self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
                           out_2.op.node_def)
    out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32],
                      name="op3")
    self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
                           out_3.op.node_def)
+
+
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.unique_name()` and its interaction with name scopes."""

  def testBasics(self):
    # Repeated requests for the same name get _1, _2, ... suffixes, and
    # suffixed names themselves get further suffixes when reused.
    g = ops.Graph()
    self.assertEquals("foo", g.unique_name("foo"))
    self.assertEquals("foo_1", g.unique_name("foo"))
    self.assertEquals("foo_2", g.unique_name("foo"))
    self.assertEquals("foo_1_1", g.unique_name("foo_1"))
    self.assertEquals("foo_1_2", g.unique_name("foo_1"))
    self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEquals("bar/foo", g.unique_name("foo"))
      self.assertEquals("bar/foo_1", g.unique_name("foo"))
      # name_scope(None) resets back to the root scope.
      with g.name_scope(None):
        self.assertEquals("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEquals("bar/baz/foo", g.unique_name("foo"))
        self.assertEquals("bar/baz/foo_1", g.unique_name("foo"))
      # Re-entering "baz" under "bar" uniquifies the scope itself.
      with g.name_scope("baz"):
        self.assertEquals("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEquals("quux/foo", g.unique_name("foo"))
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEquals("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEquals("foo_4", g.unique_name("foo"))
    self.assertEquals("bar_2", g.unique_name("bar"))

  def testOutOfOrderUniqueName(self):
    # Claiming "foo_2" up front does not disturb the sequence for "foo";
    # the collision is simply skipped over.
    g = ops.Graph()
    self.assertEquals("foo_2", g.unique_name("foo_2"))
    self.assertEquals("foo", g.unique_name("foo"))
    self.assertEquals("foo_1", g.unique_name("foo"))
    self.assertEquals("foo_3", g.unique_name("foo"))
+
+
class NameTest(test_util.TensorFlowTestCase):
  """Tests for op and tensor naming via `create_op` and `name_scope`."""

  def testGenerateName(self):
    # Default names come from the op type; output tensors get ":<index>".
    g = ops.Graph()
    op0 = g.create_op("const", [], [types.float32, types.float32])
    self.assertEquals("const", op0.name)
    self.assertEquals("const:0", op0.outputs[0].name)
    self.assertEquals("const:1", op0.outputs[1].name)

    op1 = g.create_op("const", [], [types.float32])
    self.assertEquals("const_1", op1.name)
    self.assertEquals("const_1:0", op1.outputs[0].name)

    op2 = g.create_op("const", [], [types.float32], name="my_op")
    self.assertEquals("my_op", op2.name)
    self.assertEquals("my_op:0", op2.outputs[0].name)

  def testname_scope(self):
    g = ops.Graph()

    # The value yielded by name_scope is the full scope path with a
    # trailing slash; None and "" both reset to the root scope.
    with g.name_scope("foo") as foo:
      self.assertEquals(foo, "foo/")
      with g.name_scope("foo2") as foo2:
        self.assertEquals(foo2, "foo/foo2/")
      with g.name_scope(None) as empty1:
        self.assertEquals(empty1, "")
        with g.name_scope("foo3") as foo3:
          self.assertEquals(foo3, "foo3/")
      with g.name_scope("") as empty2:
        self.assertEquals(empty2, "")

    self.assertEquals("const",
                      g.create_op("const", [], [types.float32]).name)
    with g.name_scope("bar") as scope:
      self.assertEquals("bar/const",
                        g.create_op("const", [], [types.float32]).name)
      self.assertEquals("bar/const_1",
                        g.create_op("const", [], [types.float32]).name)
      # If you use the value from "with .. as", that value is used as-is.
      self.assertEquals(
          "bar",
          g.create_op("const", [], [types.float32], name=scope).name)
    with g.name_scope("baz") as scope:
      with g.name_scope("quux"):
        self.assertEquals("baz/quux/const",
                          g.create_op("const", [], [types.float32]).name)
      # If you use the value from the enclosing "with .. as", nothing is pushed.
      with g.name_scope(scope):
        self.assertEquals("baz/const",
                          g.create_op("const", [], [types.float32]).name)
        self.assertEquals("baz",
                          g.create_op("const", [], [types.float32],
                                      name=scope).name)
        # A trailing "/" in an explicit op name is stripped.
        self.assertEquals("trailing",
                          g.create_op("const", [], [types.float32],
                                      name="trailing/").name)
    with g.name_scope("bar"):
      self.assertEquals("bar_1/const",
                        g.create_op("const", [], [types.float32]).name)
    # A scope name ending in "/" re-enters the existing scope verbatim.
    with g.name_scope("bar/"):
      self.assertEquals("bar/const_2",
                        g.create_op("const", [], [types.float32]).name)
+
+
+class DeviceTest(test_util.TensorFlowTestCase):
+  """Tests that Graph.device() scopes control the device field of new ops."""
+
+  def testNoDevice(self):
+    # Ops created outside any device scope carry no device assignment.
+    g = ops.Graph()
+    op = g.create_op("an_op", [], [types.float32])
+    self.assertEqual(None, op.device)
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op" }
+    """, gd)
+
+  def testDevicePartialString(self):
+    # A partial device spec (job/replica only) is recorded verbatim.
+    g = ops.Graph()
+    with g.device("/job:worker/replica:2"):
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" }
+    """, gd)
+
+  def testDeviceFull(self):
+    # A fully-specified pydev.Device object is serialized to its string form.
+    g = ops.Graph()
+    with g.device(pydev.Device(job="worker", replica=2, task=0,
+                               device_type="CPU",
+                               device_index=3)):
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/job:worker/replica:2/task:0/device:CPU:3" }
+    """, gd)
+
+  def testNesting(self):
+    # An inner device scope overrides the outer one, and the outer scope is
+    # restored when the inner scope exits.
+    g = ops.Graph()
+    with g.device("/job:worker/replica:2"):
+      g.create_op("an_op", [], [types.float32])
+      with g.device("/job:worker/replica:3/task:0"):
+        g.create_op("an_op", [], [types.float32])
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/job:worker/replica:2" }
+      node { name: "an_op_1" op: "an_op"
+             device: "/job:worker/replica:3/task:0" }
+      node { name: "an_op_2" op: "an_op"
+             device: "/job:worker/replica:2" }
+    """, gd)
+
+  def testNestingString(self):
+    # NOTE(review): this test is byte-identical to testNesting above --
+    # presumably one of the two was meant to exercise pydev.Device objects
+    # rather than strings; confirm the intended difference.
+    g = ops.Graph()
+    with g.device("/job:worker/replica:2"):
+      g.create_op("an_op", [], [types.float32])
+      with g.device("/job:worker/replica:3/task:0"):
+        g.create_op("an_op", [], [types.float32])
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/job:worker/replica:2" }
+      node { name: "an_op_1" op: "an_op"
+             device: "/job:worker/replica:3/task:0" }
+      node { name: "an_op_2" op: "an_op"
+             device: "/job:worker/replica:2" }
+    """, gd)
+
+  def testNestingOverrideGpuCpu(self):
+    # An inner scope may override just the device type/index while keeping
+    # the job/replica parts of the outer scope.
+    g = ops.Graph()
+    with g.device("/job:worker/replica:2/device:CPU:1"):
+      g.create_op("an_op", [], [types.float32])
+      with g.device("/job:worker/replica:2/device:GPU:2"):
+        g.create_op("an_op", [], [types.float32])
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/job:worker/replica:2/device:CPU:1" }
+      node { name: "an_op_1" op: "an_op"
+             device: "/job:worker/replica:2/device:GPU:2" }
+      node { name: "an_op_2" op: "an_op"
+             device: "/job:worker/replica:2/device:CPU:1" }
+    """, gd)
+
+  def testNestingWithMergeDeviceFunction(self):
+    # merge_device() composes with enclosing scopes field-by-field instead of
+    # replacing them; merge_device(None) leaves the merged spec unchanged.
+    g = ops.Graph()
+
+    with g.device(pydev.merge_device("/device:GPU:0")):
+      g.create_op("an_op", [], [types.float32])
+      with g.device(pydev.merge_device("/job:worker")):
+        g.create_op("an_op", [], [types.float32])
+        with g.device(pydev.merge_device("/device:CPU:0")):
+          g.create_op("an_op", [], [types.float32])
+          with g.device(pydev.merge_device("/job:ps")):
+            g.create_op("an_op", [], [types.float32])
+            with g.device(pydev.merge_device(None)):
+              g.create_op("an_op", [], [types.float32])
+
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/device:GPU:0" }
+      node { name: "an_op_1" op: "an_op"
+             device: "/job:worker/device:GPU:0" }
+      node { name: "an_op_2" op: "an_op"
+             device: "/job:worker/device:CPU:0" }
+      node { name: "an_op_3" op: "an_op"
+             device: "/job:ps/device:CPU:0" }
+      node { name: "an_op_4" op: "an_op"
+             device: "/job:ps/device:CPU:0" }
+    """, gd)
+
+  def testNoneClearsDefault(self):
+    # g.device(None) clears the enclosing device scope entirely.
+    g = ops.Graph()
+    with g.device("/job:worker/replica:2/device:CPU:1"):
+      g.create_op("an_op", [], [types.float32])
+      with g.device(None):
+        g.create_op("an_op", [], [types.float32])
+      g.create_op("an_op", [], [types.float32])
+    gd = g.as_graph_def()
+    self.assertProtoEquals("""
+      node { name: "an_op" op: "an_op"
+             device: "/job:worker/replica:2/device:CPU:1" }
+      node { name: "an_op_1" op: "an_op" }
+      node { name: "an_op_2" op: "an_op"
+             device: "/job:worker/replica:2/device:CPU:1" }
+    """, gd)
+
+
+class ObjectWithName(object):
+  """Minimal object exposing a read-only .name, for collection scope tests."""
+
+  def __init__(self, name):
+    # Stored name; exposed via the read-only property below.
+    self._name = name
+
+  @property
+  def name(self):
+    return self._name
+
+
+class CollectionTest(test_util.TensorFlowTestCase):
+
+ def testadd_to_collection(self):
+ g = ops.Graph()
+ g.add_to_collection("key", 12)
+ g.add_to_collection("other", "foo")
+ g.add_to_collection("key", 34)
+
+ # Note that only blank1 is returned.
+ g.add_to_collection("blah", 27)
+ blank1 = ObjectWithName("prefix/foo")
+ g.add_to_collection("blah", blank1)
+ blank2 = ObjectWithName("junk/foo")
+ g.add_to_collection("blah", blank2)
+
+ self.assertEquals(["foo"], g.get_collection("other"))
+ self.assertEquals([12, 34], g.get_collection("key"))
+ self.assertEquals([], g.get_collection("nothing"))
+ self.assertEquals([27, blank1, blank2], g.get_collection("blah"))
+ self.assertEquals([blank1], g.get_collection("blah", "prefix"))
+
+ def testDefaulGraph(self):
+ with ops.Graph().as_default():
+ ops.add_to_collection("key", 90)
+ ops.add_to_collection("key", 100)
+ # Collections are ordered.
+ self.assertEquals([90, 100], ops.get_collection("key"))
+
+
+def an_op(g):
+  # Creates a no-input op named "an_op" with one float32 output in graph g.
+  return _apply_op(g, "an_op", [], [types.float32])
+
+
+# "an_op" is deliberately registered as having no gradient, so the
+# RegistrationTest cases below exercise only the "copy" op's gradient.
+ops.NoGradient("an_op")
+
+
+def copy_op(x):
+  # Applies the "copy" op to tensor x, producing one output of the same dtype.
+  return _apply_op(x.graph, "copy", [x], [x.dtype])
+
+
+@ops.RegisterGradient("copy")
+def _CopyGrad(op, x_grad):
+  # Default gradient for "copy": pass the incoming gradient through unchanged.
+  _ = op
+  return x_grad
+
+
+@ops.RegisterGradient("copy_override")
+def _CopyOverrideGrad(op, x_grad):
+  # Alternate gradient, selected via gradient_override_map in the tests below.
+  _ = op
+  return x_grad
+
+
+class RegistrationTest(test_util.TensorFlowTestCase):
+  """Tests for gradient-function registration and override lookup."""
+
+  def testRegisterGradients(self):
+    # The registered _CopyGrad is found for ops of type "copy".
+    g = ops.Graph()
+    x = an_op(g)
+    y = copy_op(x)
+    fn = ops.get_gradient_function(y.op)
+    self.assertEquals(_CopyGrad, fn)
+
+  def testOverrideGradients(self):
+    # gradient_override_map redirects "copy" to the "copy_override" gradient.
+    g = ops.Graph()
+    x = an_op(g)
+    with g.gradient_override_map({"copy": "copy_override"}):
+      y = copy_op(x)
+    fn = ops.get_gradient_function(y.op)
+    self.assertEquals(_CopyOverrideGrad, fn)
+
+  def testNonExistentOverride(self):
+    # Overriding to an unregistered gradient name fails at lookup time,
+    # not when the op is created.
+    g = ops.Graph()
+    x = an_op(g)
+    with g.gradient_override_map({"copy": "unknown_override"}):
+      y = copy_op(x)
+    with self.assertRaisesRegexp(LookupError, "unknown_override"):
+      fn = ops.get_gradient_function(y.op)
+
+
+class ComparisonTest(test_util.TensorFlowTestCase):
+  """Tests that Tensor objects support membership (in / not in) checks."""
+
+  def testMembershipAllowed(self):
+    g = ops.Graph()
+    t1 = _apply_op(g, "const", [], [types.float32], name="myop1")
+    t2 = _apply_op(g, "const", [], [types.float32], name="myop2")
+    self.assertTrue(isinstance(t1, ops.Tensor))
+    self.assertTrue(isinstance(t2, ops.Tensor))
+    # Membership tests must work even though Tensor overloads comparisons.
+    self.assertTrue(t1 in [t1])
+    self.assertTrue(t1 not in [t2])
+
+
+class ControlDependenciesTest(test_util.TensorFlowTestCase):
+  """Tests for Graph.control_dependencies() scoping and pruning rules."""
+
+  def testBasic(self):
+    g = ops.Graph()
+    a = _apply_op(g, "const", [], [types.float32])
+    b = _apply_op(g, "const", [], [types.float32])
+    with g.control_dependencies([a]):
+      c = _apply_op(g, "const", [], [types.float32])
+      d = _apply_op(g, "identity", [b], [types.float32])
+      e = _apply_op(g, "identity", [c], [types.float32])
+
+    self.assertEqual(c.op.control_inputs, [a.op])
+    self.assertEqual(d.op.control_inputs, [a.op])
+    # e should be dominated by c.
+    self.assertEqual(e.op.control_inputs, [])
+
+  def testNested(self):
+    # One scope listing several dependencies is equivalent to nesting one
+    # scope per dependency.
+    g = ops.Graph()
+    a_1 = _apply_op(g, "const", [], [types.float32])
+    a_2 = _apply_op(g, "const", [], [types.float32])
+    a_3 = _apply_op(g, "const", [], [types.float32])
+    a_4 = _apply_op(g, "const", [], [types.float32])
+
+    with g.control_dependencies([a_1, a_2, a_3, a_4]):
+      b_1 = _apply_op(g, "const", [], [types.float32])
+
+    with g.control_dependencies([a_1]):
+      with g.control_dependencies([a_2]):
+        with g.control_dependencies([a_3]):
+          with g.control_dependencies([a_4]):
+            b_2 = _apply_op(g, "const", [], [types.float32])
+
+    self.assertItemsEqual(
+        [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs)
+    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
+
+  def testComplex(self):
+    g = ops.Graph()
+
+    # Usage pattern:
+    # * Nodes a_i are constants defined at the outermost scope, and are used
+    #   as control inputs for the ith nested scope.
+    # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
+    # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
+    # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
+    # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
+
+    a_1 = _apply_op(g, "const", [], [types.float32])
+    a_2 = _apply_op(g, "const", [], [types.float32])
+    a_3 = _apply_op(g, "const", [], [types.float32])
+    a_4 = _apply_op(g, "const", [], [types.float32])
+
+    with g.control_dependencies([a_1]):
+      b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+      c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+      d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32])
+      e_1 = _apply_op(g, "const", [], [types.float32])
+      with g.control_dependencies([a_2]):
+        b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+        c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+        d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32])
+        e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32])
+        with g.control_dependencies([a_3]):
+          b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+          c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+          d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32])
+          e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32])
+          with g.control_dependencies([a_4]):
+            b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32])
+            c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32])
+            d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32])
+            e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32])
+
+    # Control deps already implied by a data dependency are pruned.
+    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
+    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
+    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
+
+    self.assertItemsEqual([], c_1.op.control_inputs)
+    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
+    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
+    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
+
+    self.assertItemsEqual([], d_1.op.control_inputs)
+    self.assertItemsEqual([], d_2.op.control_inputs)
+    self.assertItemsEqual([], d_3.op.control_inputs)
+    self.assertItemsEqual([], d_4.op.control_inputs)
+
+    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
+    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
+    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
+    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
+
+  def testRepeatedDependency(self):
+    # Two outputs of the same op yield a single control input on that op.
+    g = ops.Graph()
+    a = g.create_op("foo", [], [types.float32, types.float32])
+    a_0, a_1 = a.outputs
+    with g.control_dependencies([a_0]):
+      b = _apply_op(g, "const", [], [types.float32])
+      with g.control_dependencies([a_1]):
+        c = _apply_op(g, "const", [], [types.float32])
+
+    self.assertEqual(b.op.control_inputs, [a])
+    self.assertEqual(c.op.control_inputs, [a])
+
+  def testNoControlDependencyWithDataDependency(self):
+    # A control dependency on a direct data input is redundant and dropped.
+    g = ops.Graph()
+    a = _apply_op(g, "const", [], [types.float32])
+    with g.control_dependencies([a]):
+      b = _apply_op(g, "identity", [a], [types.float32])
+
+    self.assertEqual(b.op.control_inputs, [])
+
+
+class GraphTest(test_util.TensorFlowTestCase):
+  """Tests for default-graph management and graph-element conversion."""
+
+  def setUp(self):
+    # Each test starts from a fresh process-wide default graph.
+    ops.reset_default_graph()
+
+  def _AssertDefault(self, expected):
+    self.assertIs(expected, ops.get_default_graph())
+
+  def testGraphContextManager(self):
+    # as_default() yields the same graph it was called on.
+    g0 = ops.Graph()
+    with g0.as_default() as g1:
+      self.assertIs(g0, g1)
+
+  def testDefaultGraph(self):
+    # Creating a Graph or its context manager does not change the default;
+    # only entering the context does, and exiting restores the previous one.
+    orig = ops.get_default_graph()
+    self._AssertDefault(orig)
+    g0 = ops.Graph()
+    self._AssertDefault(orig)
+    context_manager_0 = g0.as_default()
+    self._AssertDefault(orig)
+    with context_manager_0 as g0:
+      self._AssertDefault(g0)
+      with ops.Graph().as_default() as g1:
+        self._AssertDefault(g1)
+      self._AssertDefault(g0)
+    self._AssertDefault(orig)
+
+  def testAsGraphElementConversions(self):
+    # Objects with an _as_graph_element() hook are resolved through it;
+    # objects without it raise TypeError.
+    class ConvertibleObj(object):
+
+      def _as_graph_element(self):
+        return "const:0"
+
+    class NonConvertibleObj(object):
+
+      pass
+
+    g = ops.Graph()
+    a = _apply_op(g, "const", [], [types.float32])
+    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
+    with self.assertRaises(TypeError):
+      g.as_graph_element(NonConvertibleObj())
+
+  def testAssertSameGraph(self):
+    g0 = ops.Graph()
+    a = g0.create_op("a", [], [types.float32])
+    b = g0.create_op("b", [], [types.float32])
+    ops.assert_same_graph([a, b])
+    ops.assert_same_graph([a, b], g0)
+    g1 = ops.Graph()
+    c = g1.create_op("c", [], [types.float32])
+    self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c])
+    self.assertRaises(ValueError, ops.assert_same_graph, [c], g0)
+    self.assertRaises(ValueError, ops.assert_same_graph, [a], g1)
+
+    # SparseTensor is a composite value; its graph is checked as a whole.
+    sparse = ops.SparseTensor(
+        _apply_op(g0, "const", [], [types.int64]),
+        _apply_op(g0, "const", [], [types.float32]),
+        _apply_op(g0, "const", [], [types.int64]))
+    ops.assert_same_graph([sparse, a, b])
+    ops.assert_same_graph([sparse, a, b], g0)
+    self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c])
+    self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1)
+
+# The KernelLabel op (from test_kernel_label_op.cc) returns a scalar string.
+ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
+
+
+class KernelLabelTest(test_util.TensorFlowTestCase):
+  """Tests for selecting kernel overloads via Graph._kernel_label_map."""
+
+  def testNoLabel(self):
+    with self.test_session():
+      self.assertAllEqual("My label is: default",
+                          test_kernel_label_op.kernel_label().eval())
+
+  def testLabelMap(self):
+    # Label maps nest and restore on scope exit; an empty label string
+    # explicitly selects the default kernel.
+    with self.test_session() as sess:
+      default_1 = test_kernel_label_op.kernel_label()
+      # pylint: disable=protected-access
+      with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
+        overload_1_1 = test_kernel_label_op.kernel_label()
+        with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
+          overload_2 = test_kernel_label_op.kernel_label()
+          with sess.graph._kernel_label_map({"KernelLabel": ""}):
+            default_2 = test_kernel_label_op.kernel_label()
+        overload_1_2 = test_kernel_label_op.kernel_label()
+      # pylint: enable=protected-access
+      default_3 = test_kernel_label_op.kernel_label()
+
+      self.assertAllEqual("My label is: default", default_1.eval())
+      self.assertAllEqual("My label is: default", default_2.eval())
+      self.assertAllEqual("My label is: default", default_3.eval())
+      self.assertAllEqual("My label is: overload_1", overload_1_1.eval())
+      self.assertAllEqual("My label is: overload_1", overload_1_2.eval())
+      self.assertAllEqual("My label is: overload_2", overload_2.eval())
+
+
+if __name__ == "__main__":
+  # Run all TestCase classes in this module.
+  googletest.main()
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
new file mode 100644
index 0000000000..5c1b4462d5
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -0,0 +1,678 @@
+#include "tensorflow/python/framework/python_op_gen.h"
+
+#include <stdio.h>
+#include <unordered_map>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+// Column at which generated docstring text is wrapped / broken.
+const int kRightMargin = 78;
+
+// Returns true if `s` would collide with a Python keyword, a Python
+// built-in, or a symbol used by the generated wrapper module itself.
+bool IsPythonReserved(const string& s) {
+  // Built lazily once and leaked intentionally (static singleton).
+  static const std::set<string>* const kPythonReserved = new std::set<string>(
+      {// Keywords in Python, from:
+       //   import keyword
+       //   print keyword.kwlist
+       "and", "as", "assert", "break", "class", "continue", "def", "del",
+       "elif", "else", "except", "exec", "finally", "for", "from", "global",
+       "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
+       "raise", "return", "try", "while", "with", "yield",
+       // Built-in functions and types in Python, from:
+       //   [x for x in dir(__builtins__) if not x[0].islower()]
+       "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
+       "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
+       "Ellipsis", "EnvironmentError", "Exception", "False",
+       "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
+       "ImportError", "ImportWarning", "IndentationError", "IndexError",
+       "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
+       "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
+       "OverflowError", "PendingDeprecationWarning", "ReferenceError",
+       "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
+       "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
+       "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
+       "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
+       "UnicodeWarning", "UserWarning", "ValueError", "Warning",
+       "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
+       "__package__",
+       // Imports and symbols used in the generated code:
+       "_op_def_lib", "text_format", "op_def_pb2", "op_def_library", "ops"});
+
+  return kPythonReserved->count(s) > 0;
+}
+
+// Add a _ to the end of s if necessary to avoid a Python keyword or built-in.
+string AvoidPythonReserved(const string& s) {
+  if (IsPythonReserved(s)) return strings::StrCat(s, "_");
+  return s;
+}
+
+// Indent the first line by "initial" spaces and all following lines
+// by "rest" spaces.
+// Indent the first line by "initial" spaces and all following lines
+// by "rest" spaces.  Trailing whitespace is stripped; empty interior
+// lines are preserved without indentation.  Every emitted line ends
+// with a newline.
+string Indent(int initial, int rest, StringPiece in) {
+  // TODO(josh11b): Also word-wrapping?
+  string copy(in.data(), in.size());
+  str_util::StripTrailingWhitespace(&copy);
+  std::vector<string> v = str_util::Split(copy, '\n');
+
+  string result;
+  bool first = true;
+  for (const string& line : v) {
+    if (first) {
+      result = strings::StrCat(Spaces(initial), line, "\n");
+      first = false;
+    } else {
+      if (line.empty()) {
+        strings::StrAppend(&result, "\n");
+      } else {
+        strings::StrAppend(&result, Spaces(rest), line, "\n");
+      }
+    }
+  }
+  return result;
+}
+
+// Adds append to *dest, with a space if the first line will be <= width,
+// or a newline otherwise.
+void AppendWithinWidth(string* dest, StringPiece append, int width) {
+  auto first_line = append.find('\n');
+  if (first_line == string::npos) first_line = append.size();
+  // Only the first line of `append` is considered when deciding whether
+  // the joined text still fits within `width`.
+  if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
+    strings::StrAppend(dest, "\n", append);
+  } else {
+    strings::StrAppend(dest, " ", append);
+  }
+}
+
+// Strips all free-text documentation (summary, description, and per-arg
+// and per-attr descriptions) from *op_def, leaving only structural fields.
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+  for (int i = 0; i < op_def->input_arg_size(); ++i) {
+    op_def->mutable_input_arg(i)->clear_description();
+  }
+  for (int i = 0; i < op_def->output_arg_size(); ++i) {
+    op_def->mutable_output_arg(i)->clear_description();
+  }
+  for (int i = 0; i < op_def->attr_size(); ++i) {
+    op_def->mutable_attr(i)->clear_description();
+  }
+  op_def->clear_summary();
+  op_def->clear_description();
+}
+
+// Like DataTypeString() but uses the Python names for the
+// float types.
+string PythonDataTypeString(DataType dtype) {
+  switch (dtype) {
+    case DT_FLOAT:
+      return "float32";
+    case DT_DOUBLE:
+      return "float64";
+    default:
+      return DataTypeString(dtype);
+  }
+}
+
+// Returns the dtype formatted for a docstring, e.g. "`float32`" or
+// "mutable `float32`" when the tensor is a reference.
+string TypeString(DataType dtype, bool ref) {
+  if (ref) {
+    return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
+  } else {
+    return strings::StrCat("`", PythonDataTypeString(dtype), "`");
+  }
+}
+
+// Formats the list of types in an AttrValue as a comma-separated docstring
+// fragment, unwrapping ref types into "<type> mutable".
+string TypeListString(const AttrValue& value) {
+  string ret;
+  for (int t : value.list().type()) {
+    if (!ret.empty()) strings::StrAppend(&ret, ", ");
+    DataType dtype = static_cast<DataType>(t);
+    if (IsRefType(dtype)) {
+      strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
+                         " mutable");
+    } else {
+      strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
+    }
+  }
+  return ret;
+}
+
+// Docstring sentence for a single tensor of a known type.
+string SingleTensorName(DataType dtype, bool is_ref) {
+  const string type_str = TypeString(dtype, is_ref);
+  return strings::StrCat("A `Tensor` of type ", type_str, ".");
+}
+
+// Fallback description used when the output type cannot be inferred.
+const char kUnknownTensorType[] = {"A `Tensor`."};
+
+// Returns the docstring type description for one input or output arg of
+// `op_def`.  `inferred_attrs` maps attr names to the first input arg each
+// attr is inferred from; `is_output` only changes the phrasing ("Has ..."
+// vs "Must ...").
+string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
+                   const std::unordered_map<string, string>& inferred_attrs,
+                   bool is_output) {
+  if (!arg.number_attr().empty()) {
+    // N Tensors with the same type
+    const string* original_arg =
+        gtl::FindOrNull(inferred_attrs, arg.number_attr());
+    string prefix;
+    if (original_arg == nullptr) {
+      // Length comes from an explicit attr parameter, not another arg.
+      prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
+    } else if (*original_arg == arg.name()) {
+      // This arg is itself the source of the length attr; mention any
+      // minimum-length constraint.
+      const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+      if (attr->has_minimum() && attr->minimum() > 0) {
+        prefix = strings::StrCat("A list of at least ", attr->minimum());
+      } else {
+        prefix = "A list of";
+      }
+    } else {
+      // Length is tied to another arg's length.
+      prefix = strings::StrCat(
+          "A list with the same number of `Tensor` objects as `",
+          AvoidPythonReserved(*original_arg), "` of");
+    }
+
+    if (arg.type() != DT_INVALID) {
+      // Fixed element type.
+      return strings::StrCat(prefix, " `Tensor` objects of type ",
+                             TypeString(arg.type(), arg.is_ref()), ".");
+    } else {
+      // Element type comes from a type attr.
+      original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
+      if (arg.is_ref()) {
+        strings::StrAppend(&prefix, " mutable");
+      }
+      if (original_arg == nullptr) {
+        return strings::StrCat(prefix, " `Tensor` objects of type ",
+                               arg.type_attr(), ".");
+      } else if (*original_arg == arg.name()) {
+        const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+        if (attr->has_allowed_values()) {
+          return strings::StrCat(prefix,
+                                 " `Tensor` objects of the same type in: ",
+                                 TypeListString(attr->allowed_values()), ".");
+        } else {
+          return strings::StrCat(prefix, " `Tensor` objects of the same type.");
+        }
+      } else {
+        return strings::StrCat(prefix, " `Tensor` objects of the same type as ",
+                               AvoidPythonReserved(*original_arg), ".");
+      }
+    }
+  } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
+    // Single tensor (or heterogeneous list) whose type(s) come from an attr.
+    const bool is_list = !arg.type_list_attr().empty();
+    const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
+    const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
+    const string mutable_str = arg.is_ref() ? "mutable " : "";
+    const string prefix =
+        is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
+                : strings::StrCat("A ", mutable_str, "`Tensor`");
+    const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
+    if (original_arg == nullptr) {
+      return strings::StrCat(prefix, " of type `", attr_name, "`.");
+    } else if (*original_arg == arg.name()) {
+      if (attr->has_allowed_values()) {
+        if (is_list) {
+          return strings::StrCat(prefix, " with types from: ",
+                                 TypeListString(attr->allowed_values()), ".");
+        } else {
+          return strings::StrCat(
+              prefix, is_output ? ". Has one of the following types: "
+                                : ". Must be one of the following types: ",
+              TypeListString(attr->allowed_values()), ".");
+        }
+      } else {
+        return strings::StrCat(prefix, ".");
+      }
+    } else {
+      return strings::StrCat(prefix,
+                             is_output ? ". Has the same type as `"
+                                       : ". Must have the same type as `",
+                             AvoidPythonReserved(*original_arg), "`.");
+    }
+  } else {
+    // Fixed type declared directly on the arg.
+    return SingleTensorName(arg.type(), arg.is_ref());
+  }
+}
+
+void PrintReturns(const OpDef& op_def,
+ const std::vector<string>& output_type_string) {
+ DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
+ const int num_outs = op_def.output_arg_size();
+ printf("\n Returns:\n");
+ if (num_outs == 0) {
+ printf(" The created Operation.\n");
+ } else {
+ if (num_outs == 1) {
+ StringPiece description = op_def.output_arg(0).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ printf("%s", Indent(4, 4, description).c_str());
+ } else {
+ // Special case of one output, don't use the name of the output unless
+ // there is no description.
+ string desc = output_type_string.empty() ? kUnknownTensorType
+ : output_type_string[0];
+ if (desc == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor type
+ // depends on the input tensor types, just use the output arg
+ // description if we can.
+ if (!description.empty()) {
+ desc = op_def.output_arg(0).description();
+ } else if (!op_def.output_arg(0).name().empty()) {
+ desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
+ " `Tensor`.");
+ }
+ } else if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 4, desc).c_str());
+ }
+ } else {
+ std::vector<string> out_names(num_outs);
+ for (int i = 0; i < num_outs; ++i) {
+ if (!op_def.output_arg(i).name().empty()) {
+ out_names[i] = op_def.output_arg(i).name();
+ } else {
+ out_names[i] = strings::StrCat("output", i);
+ }
+ }
+ printf(" A tuple of `Tensor` objects (%s).\n",
+ str_util::Join(out_names, ", ").c_str());
+ for (int i = 0; i < num_outs; ++i) {
+ string desc = strings::StrCat(out_names[i], ": ");
+ StringPiece description = op_def.output_arg(i).description();
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ strings::StrAppend(&desc, description);
+ } else {
+ const string type = static_cast<size_t>(i) < output_type_string.size()
+ ? output_type_string[i]
+ : kUnknownTensorType;
+ if (!description.empty()) {
+ if (type == kUnknownTensorType) {
+ // Special case where we don't understand how the output tensor
+ // type depends on the input tensor types, so we just use the
+ // output arg description.
+ strings::StrAppend(&desc, description);
+ } else {
+ strings::StrAppend(&desc, type, " ", description);
+ }
+ } else {
+ strings::StrAppend(&desc, type);
+ }
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+ }
+ }
+}
+
+// Renders `str` as a double-quoted, C-escaped Python string literal.
+string StringToPython(const string& str) {
+  return strings::StrCat("\"", str_util::CEscape(str), "\"");
+}
+
+// Renders a DataType as its Python-side name, e.g. "tf.float32".
+string DataTypeToPython(DataType dtype) {
+  return strings::StrCat("tf.", PythonDataTypeString(dtype));
+}
+
+// Renders a TensorShapeProto as a Python list; named dims become
+// ("name", size) tuples, unnamed dims become bare sizes.
+string ShapeToPython(const TensorShapeProto& shape) {
+  string python = "[";
+  for (const auto& dim : shape.dim()) {
+    if (python.size() > 1) strings::StrAppend(&python, ", ");
+    if (!dim.name().empty()) {
+      strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
+                         dim.size(), ")");
+    } else {
+      strings::StrAppend(&python, dim.size());
+    }
+  }
+  strings::StrAppend(&python, "]");
+  return python;
+}
+
+// Renders the elements of a list-valued AttrValue as a comma-separated
+// sequence of Python literals.  Exactly one of the list fields is expected
+// to be populated; an empty list yields "".
+string AttrListToPython(const AttrValue& value) {
+  string ret;
+  if (value.list().s_size() > 0) {
+    for (int i = 0; i < value.list().s_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, StringToPython(value.list().s(i)));
+    }
+  } else if (value.list().i_size() > 0) {
+    for (int i = 0; i < value.list().i_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().i(i));
+    }
+  } else if (value.list().f_size() > 0) {
+    for (int i = 0; i < value.list().f_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().f(i));
+    }
+  } else if (value.list().b_size() > 0) {
+    for (int i = 0; i < value.list().b_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
+    }
+  } else if (value.list().type_size() > 0) {
+    for (int i = 0; i < value.list().type_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, DataTypeToPython(value.list().type(i)));
+    }
+  } else if (value.list().shape_size() > 0) {
+    for (int i = 0; i < value.list().shape_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
+    }
+  }
+  return ret;
+}
+
+// Renders an AttrValue of the given attr `type` (e.g. "string", "int",
+// "list(shape)") as a Python literal.  Unrecognized types are treated as
+// lists.
+string AttrValueToPython(const string& type, const AttrValue& value) {
+  if (type == "string") {
+    return StringToPython(value.s());
+  } else if (type == "int") {
+    return strings::StrCat(value.i());
+  } else if (type == "float") {
+    return strings::StrCat(value.f());
+  } else if (type == "bool") {
+    return value.b() ? "True" : "False";
+  } else if (type == "type") {
+    return DataTypeToPython(value.type());
+  } else if (type == "shape") {
+    return ShapeToPython(value.shape());
+  } else {
+    return strings::StrCat("[", AttrListToPython(value), "]");
+  }
+}
+
+// Requires: ValidateOpDef(op_def).ok()
+void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
+ // Map from attr name to the first input arg it is inferred from.
+ std::unordered_map<string, string> inferred_attrs;
+ // This has all the input args followed by those attrs that don't have
+ // defaults.
+ std::vector<string> args_no_default;
+ // The parameters with defaults (these have to be listed after those without).
+ // No input args are included, just attrs and the graph ("g") parameter.
+ std::vector<string> args_with_defaults;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ args_no_default.push_back(arg.name());
+ if (!arg.type_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
+ } else if (!arg.type_list_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
+ arg.name());
+ }
+ if (!arg.number_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
+ }
+ }
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ const auto& attr(op_def.attr(i));
+ // Do not add inferred attrs to the Python function signature.
+ if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
+ if (attr.has_default_value()) {
+ args_with_defaults.push_back(attr.name());
+ } else {
+ args_no_default.push_back(attr.name());
+ }
+ }
+ }
+
+ // Save the list of attr parameters (attrs that won't be inferred),
+ // those with defaults go at the end.
+ std::vector<string> attrs;
+ // Get the attrs in the order we want by taking the attrs without defaults
+ // from the end of args_no_default, and adding args_no_default (before
+ // "g" gets added to args_no_default, so it only has attrs).
+ attrs.reserve(args_no_default.size() - op_def.input_arg_size() +
+ args_with_defaults.size());
+ attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(),
+ args_no_default.end());
+ attrs.insert(attrs.end(), args_with_defaults.begin(),
+ args_with_defaults.end());
+
+ std::vector<string> param_names;
+ param_names.reserve(args_no_default.size() + args_with_defaults.size());
+ string parameters;
+ for (const string& name : args_no_default) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param = AvoidPythonReserved(name);
+ strings::StrAppend(&parameters, param);
+ param_names.push_back(param);
+ }
+ for (const string& name : args_with_defaults) {
+ if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+ const string param = AvoidPythonReserved(name);
+ strings::StrAppend(&parameters, param, "=None");
+ param_names.push_back(param);
+ }
+ const bool has_args = args_no_default.size() + args_with_defaults.size() > 0;
+
+ // Print: def Function(parameters):
+ const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name);
+
+ const string def_prefix = strings::StrCat("def ", lower_op_name, "(");
+ const string def_suffix =
+ strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
+
+ printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str());
+
+ // Format the Op's descriptions so that it can be a Python docstring.
+ string comment;
+ if (op_def.summary().empty()) {
+ comment = "TODO: add doc.\n";
+ } else {
+ comment = strings::StrCat(op_def.summary(), "\n");
+ if (!op_def.description().empty()) {
+ strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description()));
+ }
+ }
+
+ printf(R"( r"""%s
+ Args:
+)",
+ comment.c_str());
+
+ // Inputs
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ StringPiece description = op_def.input_arg(i).description();
+ string desc;
+ if (ConsumeEquals(&description)) { // Skip the generated type info.
+ desc = strings::StrCat(param_names[i], ": ");
+ } else {
+ desc = strings::StrCat(param_names[i], ": ",
+ ArgTypeName(op_def, arg, inferred_attrs, false));
+ }
+ if (!description.empty()) {
+ AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+
+ // Attrs
+ for (const string& name : attrs) {
+ const auto& attr = *FindAttr(name, op_def);
+ string desc = strings::StrCat(AvoidPythonReserved(name), ": ");
+
+ static const char* const kAttrTypeName[][2] = {
+ {"string", "`string`"},
+ {"list(string)", "list of `strings`"},
+ {"int", "`int`"},
+ {"list(int)", "list of `ints`"},
+ {"float", "`float`"},
+ {"list(float)", "list of `floats`"},
+ {"bool", "`bool`"},
+ {"list(bool)", "list of `bools`"},
+ {"type", "`tf.DType`"},
+ {"list(type)", "list of `tf.DTypes`"},
+ {"shape", "`tf.TensorShape` or list of `ints`"},
+ {"list(shape)",
+ "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
+ };
+ for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+ if (attr.type() == kAttrTypeName[i][0]) {
+ string s;
+ if (attr.has_default_value()) {
+ s = strings::StrCat("optional ", kAttrTypeName[i][1]);
+ } else {
+ s = kAttrTypeName[i][1];
+ }
+ if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
+ strings::StrAppend(&desc, "An ", s);
+ } else {
+ strings::StrAppend(&desc, "A ", s);
+ }
+ break;
+ }
+ }
+
+ if (attr.has_allowed_values()) {
+ strings::StrAppend(&desc, " from: `",
+ AttrListToPython(attr.allowed_values()), "`");
+ }
+
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
+ } else if (attr.minimum() > 0) {
+ strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
+ }
+ }
+
+ strings::StrAppend(&desc, ".");
+
+ if (attr.has_default_value()) {
+ strings::StrAppend(&desc, " Defaults to `",
+ AttrValueToPython(attr.type(), attr.default_value()),
+ "`.");
+ }
+
+ if (!attr.description().empty()) {
+ AppendWithinWidth(&desc, attr.description(),
+ kRightMargin - 4 /* indent */);
+ }
+ printf("%s", Indent(4, 6, desc).c_str());
+ }
+
+ printf(" name: A name for the operation (optional).\n");
+
+ std::vector<string> output_type_string;
+ output_type_string.reserve(op_def.output_arg_size());
+ for (int i = 0; i < op_def.output_arg_size(); ++i) {
+ output_type_string.push_back(
+ ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
+ }
+ PrintReturns(op_def, output_type_string);
+
+ string return_prefix = strings::StrCat(" return _op_def_lib.apply_op(");
+ string return_args = strings::StrCat("\"", op_def.name(), "\", ");
+ for (size_t i = 0; i < param_names.size(); ++i) {
+ strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", ");
+ }
+ strings::StrAppend(&return_args, "name=name)");
+
+ printf(R"( """
+%s
+)",
+ // Wrap the arguments, and indent to the (.
+ WordWrap(return_prefix, return_args, kRightMargin).c_str());
+
+ printf("\n\n");
+}
+
+// Converts a CamelCase op name `str` to lower_snake_case, appending the
+// result to `*result`.  An underscore is emitted at a lower->upper
+// transition, and before an upper that is followed by a lower, so e.g.
+// "HTMLParser" becomes "html_parser".
+// NOTE(review): isupper/islower/tolower are called on raw chars; assumes
+// op names are plain ASCII -- confirm (negative char values would be UB).
+void GenerateLowerCaseOpName(const string& str, string* result) {
+  char joiner = '_';
+  int last_index = str.size() - 1;
+  for (int i = 0; i <= last_index; ++i) {
+    char c = str[i];
+    // Emit a joiner only if a previous-lower-to-now-upper or a
+    // now-upper-to-next-lower transition happens.
+    if (isupper(c) && (i > 0)) {
+      if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
+        result->push_back(joiner);
+      }
+    }
+    result->push_back(tolower(c));
+  }
+}
+
+} // namespace
+
+// Prints a complete generated Python module for `ops` to stdout: a fixed
+// header with imports, one wrapper function per op (hidden ops get a
+// leading '_'), optional no-op shape registrations, and a module-level
+// OpDefLibrary initialized from a description-stripped copy of the op list.
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+                    bool require_shapes) {
+  // Header
+  // TODO(josh11b): Mention the library for which wrappers are being generated.
+  printf(R"("""Python wrappers around Brain.
+
+This file is MACHINE GENERATED! Do not edit.
+"""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import op_def_library
+
+
+)");
+
+  std::vector<string> hidden_vec = str_util::Split(hidden_ops, ',');
+
+  // We'll make a copy of ops that filters out descriptions.
+  OpList cleaned_ops;
+  auto out = cleaned_ops.mutable_op();
+  out->Reserve(ops.op_size());
+  for (const auto& op_def : ops.op()) {
+    // Ops named in the comma-separated `hidden_ops` list get a '_' prefix.
+    bool is_hidden = false;
+    for (const string& hidden : hidden_vec) {
+      if (op_def.name() == hidden) {
+        is_hidden = true;
+        break;
+      }
+    }
+
+    // PrintPythonOp(op_def, is_hidden, op_def.name());
+    string lower_case_name;
+    GenerateLowerCaseOpName(op_def.name(), &lower_case_name);
+
+    // When users create custom python wrappers, they may link in the
+    // default op registry by accident, and because they can't
+    // enumerate all 'hidden' symbols, this guard is to prevent
+    // instantiating a python reserved word in their wrapper.
+    if (!is_hidden && IsPythonReserved(lower_case_name)) {
+      continue;
+    }
+
+    PrintPythonOp(op_def, is_hidden, lower_case_name);
+
+    if (!require_shapes) {
+      // Emit a no-op shape registration so shape inference accepts the op.
+      printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str());
+    }
+
+    // Copy the op into cleaned_ops with its documentation stripped; the
+    // cleaned list is embedded below as the module's op_list_ascii.
+    auto added = out->Add();
+    *added = op_def;
+    RemoveDescriptionsFromOpDef(added);
+  }
+
+  printf(R"(def _InitOpDefLibrary():
+  op_list = op_def_pb2.OpList()
+  text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
+  op_def_registry.register_op_list(op_list)
+  op_def_lib = op_def_library.OpDefLibrary()
+  op_def_lib.add_op_list(op_list)
+  return op_def_lib
+
+
+_InitOpDefLibrary.op_list_ascii = """%s"""
+
+
+_op_def_lib = _InitOpDefLibrary()
+)",
+         cleaned_ops.DebugString().c_str());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h
new file mode 100644
index 0000000000..488f7431e0
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
+
+#include <string>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Result is printed to stdout. hidden_ops should be a comma-separated
+// list of Op names that should get a leading _ in the output.
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+ bool require_shapes);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
new file mode 100644
index 0000000000..29afe35598
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -0,0 +1,30 @@
+#include "tensorflow/python/framework/python_op_gen.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+
+// Exports the global op registry and prints Python wrappers for every op.
+// `hidden` is the comma-separated list of op names to prefix with '_'.
+// NOTE(review): the first argument to Export controls which ops are
+// included -- confirm the semantics of `false` against OpRegistry::Export.
+void PrintAllPythonOps(const char* hidden, bool require_shapes) {
+  OpList ops;
+  OpRegistry::Global()->Export(false, &ops);
+  PrintPythonOps(ops, hidden, require_shapes);
+}
+
+} // namespace
+} // namespace tensorflow
+
+// Entry point.  Usage:
+//   python_op_gen_main <require_shapes>            (no hidden ops)
+//   python_op_gen_main <hidden_ops> <require_shapes>
+// where <require_shapes> is "1" for true, and <hidden_ops> is a
+// comma-separated list of op names.  Output goes to stdout.
+int main(int argc, char* argv[]) {
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  if (argc == 2) {
+    tensorflow::PrintAllPythonOps("", std::string(argv[1]) == "1");
+  } else if (argc == 3) {
+    tensorflow::PrintAllPythonOps(argv[1], std::string(argv[2]) == "1");
+  } else {
+    // Any other argument count is a usage error.
+    return -1;
+  }
+  return 0;
+}
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
new file mode 100644
index 0000000000..d0ffee7042
--- /dev/null
+++ b/tensorflow/python/framework/random_seed.py
@@ -0,0 +1,136 @@
+"""For seeding individual ops based on a graph-level seed.
+"""
+
+from tensorflow.python.framework import ops
+
+
+_DEFAULT_GRAPH_SEED = 87654321
+
+
+def get_seed(op_seed):
+  """Returns the local seeds an operation should use given an op-specific seed.
+
+  Given operation-specific seed, `op_seed`, this helper function returns two
+  seeds derived from graph-level and op-level seeds. Many random operations
+  internally use the two seeds to allow user to change the seed globally for a
+  graph, or for only specific operations.
+
+  For details on how the graph-level seed interacts with op seeds, see
+  [`set_random_seed`](constant_op.md#set_random_seed).
+
+  Args:
+    op_seed: integer.
+
+  Returns:
+    A tuple of two integers that should be used for the local seed of this
+    operation.
+  """
+  graph_seed = ops.get_default_graph().seed
+  if graph_seed is not None:
+    if op_seed is not None:
+      return graph_seed, op_seed
+    else:
+      # No op-level seed: derive a distinct, deterministic second seed from
+      # the graph's op counter.
+      # NOTE(review): reaches into the private Graph._last_id attribute --
+      # verify this stays in sync with the Graph implementation.
+      return graph_seed, ops.get_default_graph()._last_id
+  else:
+    if op_seed is not None:
+      # Op seed without a graph seed: pair it with the fixed default.
+      return _DEFAULT_GRAPH_SEED, op_seed
+    else:
+      # Neither seed set: (None, None) leaves the op non-deterministic.
+      return None, None
+
+
+def set_random_seed(seed):
+  """Sets the graph-level random seed.
+
+  Operations that rely on a random seed actually derive it from two seeds:
+  the graph-level and operation-level seeds. This sets the graph-level seed.
+
+  Its interactions with operation-level seeds is as follows:
+
+  1. If neither the graph-level nor the operation seed is set:
+    A random seed is used for this op.
+  2. If the graph-level seed is set, but the operation seed is not:
+    The system deterministically picks an operation seed in conjunction
+    with the graph-level seed so that it gets a unique random sequence.
+  3. If the graph-level seed is not set, but the operation seed is set:
+    A default graph-level seed and the specified operation seed are used to
+    determine the random sequence.
+  4. If both the graph-level and the operation seed are set:
+    Both seeds are used in conjunction to determine the random sequence.
+
+  To illustrate the user-visible effects, consider these examples:
+
+  To generate different sequences across sessions, set neither
+  graph-level nor op-level seeds:
+
+  ```python
+  a = tf.random_uniform([1])
+  b = tf.random_normal([1])
+
+  print "Session 1"
+  with tf.Session() as sess1:
+    print sess1.run(a)  # generates 'A1'
+    print sess1.run(a)  # generates 'A2'
+    print sess1.run(b)  # generates 'B1'
+    print sess1.run(b)  # generates 'B2'
+
+  print "Session 2"
+  with tf.Session() as sess2:
+    print sess2.run(a)  # generates 'A3'
+    print sess2.run(a)  # generates 'A4'
+    print sess2.run(b)  # generates 'B3'
+    print sess2.run(b)  # generates 'B4'
+  ```
+
+  To generate the same repeatable sequence for an op across sessions, set the
+  seed for the op:
+
+  ```python
+  a = tf.random_uniform([1], seed=1)
+  b = tf.random_normal([1])
+
+  # Repeatedly running this block with the same graph will generate the same
+  # sequence of values for 'a', but different sequences of values for 'b'.
+  print "Session 1"
+  with tf.Session() as sess1:
+    print sess1.run(a)  # generates 'A1'
+    print sess1.run(a)  # generates 'A2'
+    print sess1.run(b)  # generates 'B1'
+    print sess1.run(b)  # generates 'B2'
+
+  print "Session 2"
+  with tf.Session() as sess2:
+    print sess2.run(a)  # generates 'A1'
+    print sess2.run(a)  # generates 'A2'
+    print sess2.run(b)  # generates 'B3'
+    print sess2.run(b)  # generates 'B4'
+  ```
+
+  To make the random sequences generated by all ops be repeatable across
+  sessions, set a graph-level seed:
+
+  ```python
+  tf.set_random_seed(1234)
+  a = tf.random_uniform([1])
+  b = tf.random_normal([1])
+
+  # Repeatedly running this block with the same graph will generate different
+  # sequences of 'a' and 'b'.
+  print "Session 1"
+  with tf.Session() as sess1:
+    print sess1.run(a)  # generates 'A1'
+    print sess1.run(a)  # generates 'A2'
+    print sess1.run(b)  # generates 'B1'
+    print sess1.run(b)  # generates 'B2'
+
+  print "Session 2"
+  with tf.Session() as sess2:
+    print sess2.run(a)  # generates 'A1'
+    print sess2.run(a)  # generates 'A2'
+    print sess2.run(b)  # generates 'B1'
+    print sess2.run(b)  # generates 'B2'
+  ```
+
+  Args:
+    seed: integer.
+  """
+  # The seed is stored on the current default graph; get_seed() reads it.
+  ops.get_default_graph().seed = seed
diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py
new file mode 100644
index 0000000000..d9556f0a06
--- /dev/null
+++ b/tensorflow/python/framework/registry.py
@@ -0,0 +1,64 @@
+"""Registry mechanism for "registering" classes/functions for general use.
+
+This is typically used with a decorator that calls Register for adding
+a class or function to a registry.
+"""
+
+import traceback
+
+from tensorflow.python.platform import logging
+
+
+# Registry mechanism below is based on mapreduce.python.mrpython.Register.
+_LOCATION_TAG = "location"
+_TYPE_TAG = "type"
+
+
+class Registry(object):
+  """Provides a registry for saving objects.
+
+  Each entry maps a string name to a registered object plus the source
+  location (captured from the call stack) where registration happened,
+  which is reported on duplicate registration.
+  """
+
+  def __init__(self, name):
+    """Creates a new registry.
+
+    Args:
+      name: a human-readable name for this registry, used in error and log
+        messages.
+    """
+    self._name = name
+    # Maps name -> {_TYPE_TAG: registered object, _LOCATION_TAG: stack entry}.
+    self._registry = dict()
+
+  def register(self, candidate, name=None):
+    """Registers a Python object "candidate" for the given "name".
+
+    Args:
+      candidate: the candidate object to add to the registry.
+      name: an optional string specifying the registry key for the candidate.
+            If None, candidate.__name__ will be used.
+    Raises:
+      KeyError: If same name is used twice.
+    """
+    if not name:
+      name = candidate.__name__
+    if name in self._registry:
+      (filename, line_number, function_name, _) = (
+          self._registry[name][_LOCATION_TAG])
+      # NOTE(review): these two literals concatenate without a space, so the
+      # message reads "...!(Previous registration..." -- confirm intended.
+      raise KeyError("Registering two %s with name '%s' !"
+                     "(Previous registration was in %s %s:%d)" %
+                     (self._name, name, function_name, filename, line_number))
+
+    logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
+    # stack trace is [this_function, Register(), user_function,...]
+    # so the user function is #2.
+    stack = traceback.extract_stack()
+    self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]}
+
+  def lookup(self, name):
+    """Looks up "name".
+
+    Args:
+      name: a string specifying the registry key for the candidate.
+    Returns:
+      Registered object if found
+    Raises:
+      LookupError: if "name" has not been registered.
+    """
+    if name in self._registry:
+      return self._registry[name][_TYPE_TAG]
+    else:
+      raise LookupError(
+          "%s registry has no entry for: %s" % (self._name, name))
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
new file mode 100644
index 0000000000..5b4f261ceb
--- /dev/null
+++ b/tensorflow/python/framework/registry_test.py
@@ -0,0 +1,38 @@
+"""Tests for tensorflow.ops.registry."""
+
+from tensorflow.python.framework import registry
+from tensorflow.python.platform import googletest
+
+
+class RegistryTest(googletest.TestCase):
+  """Tests registration/lookup of classes and functions in a Registry."""
+
+  # Dummy class used as a registration candidate.
+  class Foo(object):
+    pass
+
+  def testRegisterClass(self):
+    """A class can be registered and looked up; unknown names raise."""
+    myreg = registry.Registry('testfoo')
+    with self.assertRaises(LookupError):
+      myreg.lookup('Foo')
+    myreg.register(RegistryTest.Foo, 'Foo')
+    assert myreg.lookup('Foo') == RegistryTest.Foo
+
+  def testRegisterFunction(self):
+    """A function can be registered and looked up; unknown names raise."""
+    myreg = registry.Registry('testbar')
+    with self.assertRaises(LookupError):
+      myreg.lookup('Bar')
+    myreg.register(bar, 'Bar')
+    assert myreg.lookup('Bar') == bar
+
+  def testDuplicate(self):
+    """Registering the same name twice raises KeyError."""
+    myreg = registry.Registry('testbar')
+    myreg.register(bar, 'Bar')
+    with self.assertRaises(KeyError):
+      myreg.register(bar, 'Bar')
+
+
+def bar():
+  """Dummy module-level function used as a registration candidate above."""
+  pass
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
new file mode 100644
index 0000000000..d4f27696d4
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -0,0 +1,743 @@
+"""Helper classes for tensor shape inference."""
+import tensorflow.python.platform
+
+
+class Dimension(object):
+  """Represents the value of one dimension in a TensorShape.
+
+  A Dimension holds either a known integer value or None ("unknown").
+  Arithmetic operators propagate unknowns: any operand whose value is None
+  yields Dimension(None).  The comparison operators (==, !=, <, <=, >, >=)
+  return None -- not False -- when either operand is unknown, giving a
+  three-valued logic.
+  """
+
+  def __init__(self, value):
+    """Creates a new Dimension with the given value.
+
+    Args:
+      value: An int-convertible value, or None for an unknown dimension.
+    """
+    if value is None:
+      self._value = None
+    else:
+      self._value = int(value)
+
+  def __repr__(self):
+    return "Dimension(%s)" % repr(self._value)
+
+  def __eq__(self, other):
+    """Returns true if `other` has the same known value as this Dimension."""
+    other = as_dimension(other)
+    # Unknown on either side makes the comparison itself unknown (None).
+    if self._value is None or other.value is None:
+      return None
+    return self._value == other.value
+
+  def __ne__(self, other):
+    """Returns true if `other` has a different known value from `self`."""
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    return self._value != other.value
+
+  def __int__(self):
+    # NOTE(review): returns None for an unknown dimension, so int(dim) would
+    # raise TypeError -- confirm callers only convert known dimensions.
+    return self._value
+
+  @property
+  def value(self):
+    """The value of this dimension, or None if it is unknown."""
+    return self._value
+
+  def is_compatible_with(self, other):
+    """Returns true if `other` is compatible with this Dimension.
+
+    Two known Dimensions are compatible if they have the same value.
+    An unknown Dimension is compatible with all other Dimensions.
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      True if this Dimension and `other` are compatible.
+    """
+    other = as_dimension(other)
+    return (self._value is None
+            or other.value is None
+            or self._value == other.value)
+
+  def assert_is_compatible_with(self, other):
+    """Raises an exception if `other` is not compatible with this Dimension.
+
+    Args:
+      other: Another Dimension.
+
+    Raises:
+      ValueError: If `self` and `other` are not compatible (see
+        is_compatible_with).
+    """
+    if not self.is_compatible_with(other):
+      raise ValueError("Dimensions %s and %s are not compatible"
+                       % (self, other))
+
+  def merge_with(self, other):
+    """Returns a Dimension that combines the information in `self` and `other`.
+
+    Dimensions are combined as follows:
+
+      Dimension(n)   .merge_with(Dimension(n))    == Dimension(n)
+      Dimension(n)   .merge_with(Dimension(None)) == Dimension(n)
+      Dimension(None).merge_with(Dimension(n))    == Dimension(n)
+      Dimension(None).merge_with(Dimension(None)) == Dimension(None)
+      Dimension(n)   .merge_with(Dimension(m)) raises ValueError for n != m
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension containing the combined information of `self` and
+      `other`.
+
+    Raises:
+      ValueError: If `self` and `other` are not compatible (see
+        is_compatible_with).
+    """
+    other = as_dimension(other)
+    self.assert_is_compatible_with(other)
+    if self._value is None:
+      return Dimension(other.value)
+    else:
+      return Dimension(self._value)
+
+  def __add__(self, other):
+    """Returns the sum of `self` and `other`.
+
+    Dimensions are summed as follows:
+
+      Dimension(m)    + Dimension(n)    == Dimension(m + n)
+      Dimension(m)    + Dimension(None) == Dimension(None)
+      Dimension(None) + Dimension(n)    == Dimension(None)
+      Dimension(None) + Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the sum of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value + other.value)
+
+  def __sub__(self, other):
+    """Returns the subtraction of `other` from `self`.
+
+    Dimensions are subtracted as follows:
+
+      Dimension(m)    - Dimension(n)    == Dimension(m - n)
+      Dimension(m)    - Dimension(None) == Dimension(None)
+      Dimension(None) - Dimension(n)    == Dimension(None)
+      Dimension(None) - Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the subtraction of `other` from `self`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value - other.value)
+
+  def __mul__(self, other):
+    """Returns the product of `self` and `other`.
+
+    Dimensions are multiplied as follows:
+
+      Dimension(m)    * Dimension(n)    == Dimension(m * n)
+      Dimension(m)    * Dimension(None) == Dimension(None)
+      Dimension(None) * Dimension(n)    == Dimension(None)
+      Dimension(None) * Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the product of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value * other.value)
+
+  def __div__(self, other):
+    """Returns the quotient of `self` and `other`.
+
+    Dimensions are divided as follows:
+
+      Dimension(m)    / Dimension(n)    == Dimension(m / n)
+      Dimension(m)    / Dimension(None) == Dimension(None)
+      Dimension(None) / Dimension(n)    == Dimension(None)
+      Dimension(None) / Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the quotient of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value / other.value)
+
+  def __mod__(self, other):
+    """Returns `self` modulo `other`.
+
+    Dimension moduli are computed  as follows:
+
+      Dimension(m)    % Dimension(n)     == Dimension(m % n)
+      Dimension(m)    % Dimension(None)  == Dimension(None)
+      Dimension(None) % Dimension(n)     == Dimension(None)
+      Dimension(None) %  Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is `self` modulo `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value % other.value)
+
+  def __lt__(self, other):
+    """Returns True if `self` is known to be less than `other`.
+
+    Dimensions are compared as follows:
+
+      Dimension(m)    < Dimension(n)    == m < n
+      Dimension(m)    < Dimension(None) == None
+      Dimension(None) < Dimension(n)    == None
+      Dimension(None) < Dimension(None) == None
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      The value of `self.value < other.value` if both are known, otherwise
+      None.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value < other.value
+
+  def __le__(self, other):
+    """Returns True if `self` is known to be less than or equal to `other`.
+
+    Dimensions are compared as follows:
+
+      Dimension(m)    <= Dimension(n)    == m <= n
+      Dimension(m)    <= Dimension(None) == None
+      Dimension(None) <= Dimension(n)    == None
+      Dimension(None) <= Dimension(None) == None
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      The value of `self.value <= other.value` if both are known, otherwise
+      None.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value <= other.value
+
+  def __gt__(self, other):
+    """Returns True if `self` is known to be greater than `other`.
+
+    Dimensions are compared as follows:
+
+      Dimension(m)    > Dimension(n)    == m > n
+      Dimension(m)    > Dimension(None) == None
+      Dimension(None) > Dimension(n)    == None
+      Dimension(None) > Dimension(None) == None
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      The value of `self.value > other.value` if both are known, otherwise
+      None.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value > other.value
+
+  def __ge__(self, other):
+    """Returns True if `self` is known to be greater than or equal to `other`.
+
+    Dimensions are compared as follows:
+
+      Dimension(m)    >= Dimension(n)    == m >= n
+      Dimension(m)    >= Dimension(None) == None
+      Dimension(None) >= Dimension(n)    == None
+      Dimension(None) >= Dimension(None) == None
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      The value of `self.value >= other.value` if both are known, otherwise
+      None.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value >= other.value
+
+
+def as_dimension(value):
+  """Converts the given value to a Dimension.
+
+  A Dimension input will be returned unmodified.
+  An input of `None` will be converted to an unknown Dimension.
+  An integer input will be converted to a Dimension with that value.
+
+  Args:
+    value: The value to be converted.
+
+  Returns:
+    A Dimension corresponding to the given value.
+  """
+  if isinstance(value, Dimension):
+    return value
+  else:
+    return Dimension(value)
+
+
+class TensorShape(object):
+  """Represents the shape of a `Tensor`.
+
+  A `TensorShape` represents a possibly-partial shape specification for a
+  `Tensor`. It may be one of the following:
+
+  * *Fully-known shape:* has a known number of dimensions and a known size
+    for each dimension.
+  * *Partially-known shape:* has a known number of dimensions, and an unknown
+    size for one or more dimension.
+  * *Unknown shape:* has an unknown number of dimensions, and an unknown
+    size in all dimensions.
+
+  If a tensor is produced by an operation of type `"Foo"`, its shape
+  may be inferred if there is a registered shape function for
+  `"Foo"`. See [`tf.RegisterShape()`](framework.md#RegisterShape)
+  for details of shape
+  functions and how to register them. Alternatively, the shape may be set
+  explicitly using [`Tensor.set_shape()`](framework.md#Tensor.set_shape).
+
+  @@merge_with
+  @@concatenate
+
+  @@ndims
+  @@dims
+  @@as_list
+  @@is_compatible_with
+  @@is_fully_defined
+
+  @@with_rank
+  @@with_rank_at_least
+  @@with_rank_at_most
+
+  @@assert_has_rank
+  @@assert_same_rank
+  @@assert_is_compatible_with
+  @@assert_is_fully_defined
+  """
+
+  def __init__(self, dims):
+    """Creates a new TensorShape with the given dimensions.
+
+    Args:
+      dims: A list of Dimensions, or None if the shape is unspecified.
+        DEPRECATED: A single integer is treated as a singleton list.
+    """
+    # TODO(irving): Eliminate the single integer special case.
+    if dims is None:
+      # Unknown shape: rank and all dimension sizes are unspecified.
+      self._dims = None
+    else:
+      try:
+        dims_iter = iter(dims)
+      except TypeError:
+        # Treat as a singleton dimension
+        self._dims = [as_dimension(dims)]
+      else:
+        # Got a list of dimensions
+        # NOTE(review): map() returns a list on Python 2 but a lazy iterator
+        # on Python 3; this code relies on the Python 2 behavior.
+        self._dims = map(as_dimension, dims_iter)
+
+  def __repr__(self):
+    return "TensorShape(%s)" % str(self._dims)
+
+  @property
+  def dims(self):
+    """Returns a list of Dimensions, or None if the shape is unspecified."""
+    return self._dims
+
+  @property
+  def ndims(self):
+    """Returns the rank of this shape, or None if it is unspecified."""
+    if self._dims is None:
+      return None
+    else:
+      return len(self._dims)
+
+  def __len__(self):
+    """Returns the rank of this shape, or raises ValueError if unspecified."""
+    if self._dims is None:
+      raise ValueError("Cannot take the length of Shape with unknown rank.")
+    return len(self._dims)
+
+  def __nonzero__(self):
+    """Returns True if this shape contains non-zero information."""
+    # Python 2 truth-value hook: bool(shape) is False only for an unknown
+    # shape (Python 3 would require __bool__ instead).
+    return self._dims is not None
+
+  def __getitem__(self, key):
+    """Returns the value of a dimension or a shape, depending on the key.
+
+    Args:
+      key: If `key` is an integer, returns the dimension at that index;
+        otherwise if `key` is a slice, returns a TensorShape whose
+        dimensions are those selected by the slice from `self`.
+
+    Returns:
+      A dimension if `key` is an integer, or a `TensorShape` if `key` is a
+      slice.
+
+    Raises:
+      ValueError: If `key` is a slice, and any of its elements are negative, or
+        if `self` is completely unknown and the step is set.
+    """
+    if self._dims is not None:
+      if isinstance(key, slice):
+        return TensorShape(self._dims[key])
+      else:
+        return self._dims[key]
+    else:
+      if isinstance(key, slice):
+        start = key.start if key.start is not None else 0
+        stop = key.stop
+
+        if key.step is not None:
+          # TODO(mrry): Handle these maybe.
+          raise ValueError("Steps are not yet handled")
+        if stop is None:
+          # NOTE(mrry): This implies that TensorShape(None) is compatible with
+          # TensorShape(None)[1:], which is obviously not true. It would be
+          # possible to track the number of dimensions symbolically,
+          # and perhaps we should do that.
+          return unknown_shape()
+        elif start < 0 or stop < 0:
+          # TODO(mrry): Handle this better, as it will be useful for handling
+          # suffixes of otherwise unknown shapes.
+          return unknown_shape()
+        else:
+          # Bounded slice of an unknown shape: rank is known, sizes are not.
+          return unknown_shape(ndims=stop-start)
+      else:
+        return Dimension(None)
+
+  def num_elements(self):
+    """Returns the total number of elements, or none for incomplete shapes."""
+    if self.is_fully_defined():
+      size = 1
+      for dim in self._dims:
+        size *= dim.value
+      return size
+    else:
+      return None
+
+  def merge_with(self, other):
+    """Returns a `TensorShape` combining the information in `self` and `other`.
+
+    The dimensions in `self` and `other` are merged elementwise,
+    according to the rules defined for `Dimension.merge_with()`.
+
+    Args:
+      other: Another `TensorShape`.
+
+    Returns:
+      A `TensorShape` containing the combined information of `self` and
+      `other`.
+
+    Raises:
+      ValueError: If `self` and `other` are not compatible.
+    """
+    other = as_shape(other)
+    if self._dims is None:
+      return other
+    else:
+      self.assert_same_rank(other)
+      new_dims = []
+      for i, dim in enumerate(self._dims):
+        new_dims.append(dim.merge_with(other[i]))
+      return TensorShape(new_dims)
+
+  def concatenate(self, other):
+    """Returns the concatenation of the dimension in `self` and `other`.
+
+    *N.B.* If either `self` or `other` is completely unknown,
+    concatenation will discard information about the other shape. In
+    future, we might support concatenation that preserves this
+    information for use with slicing.
+
+    Args:
+      other: Another `TensorShape`.
+
+    Returns:
+      A `TensorShape` whose dimensions are the concatenation of the
+      dimensions in `self` and `other`.
+    """
+    # TODO(mrry): Handle the case where we concatenate a known shape with a
+    # completely unknown shape, so that we can use the partial information.
+    other = as_shape(other)
+    if self._dims is None or other.dims is None:
+      return unknown_shape()
+    else:
+      return TensorShape(self._dims + other.dims)
+
+  def assert_same_rank(self, other):
+    """Raises an exception if `self` and `other` do not have compatible ranks.
+
+    Args:
+      other: Another `TensorShape`.
+
+    Raises:
+      ValueError: If `self` and `other` do not represent shapes with the
+        same rank.
+    """
+    other = as_shape(other)
+    # An unknown rank on either side is compatible with any rank.
+    if self.ndims is not None and other.ndims is not None:
+      if self.ndims != other.ndims:
+        raise ValueError(
+            "Shapes %s and %s must have the same rank" % (self, other))
+
+  def assert_has_rank(self, rank):
+    """Raises an exception if `self` is not compatible with the given `rank`.
+
+    Args:
+      rank: An integer.
+
+    Raises:
+      ValueError: If `self` does not represent a shape with the given `rank`.
+    """
+    if self.ndims not in (None, rank):
+      raise ValueError("Shape %s must have rank %d" % (self, rank))
+
+  def with_rank(self, rank):
+    """Returns a shape based on `self` with the given rank.
+
+    This method promotes a completely unknown shape to one with a
+    known rank.
+
+    Args:
+      rank: An integer.
+
+    Returns:
+      A shape that is at least as specific as `self` with the given rank.
+
+    Raises:
+      ValueError: If `self` does not represent a shape with the given `rank`.
+    """
+    return self.merge_with(unknown_shape(ndims=rank))
+
+  def with_rank_at_least(self, rank):
+    """Returns a shape based on `self` with at least the given rank.
+
+    Args:
+      rank: An integer.
+
+    Returns:
+      A shape that is at least as specific as `self` with at least the given
+      rank.
+
+    Raises:
+      ValueError: If `self` does not represent a shape with at least the given
+        `rank`.
+    """
+    if self.ndims is not None and self.ndims < rank:
+      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
+    else:
+      return self
+
+  def with_rank_at_most(self, rank):
+    """Returns a shape based on `self` with at most the given rank.
+
+    Args:
+      rank: An integer.
+
+    Returns:
+      A shape that is at least as specific as `self` with at most the given
+      rank.
+
+    Raises:
+      ValueError: If `self` does not represent a shape with at most the given
+        `rank`.
+    """
+    if self.ndims is not None and self.ndims > rank:
+      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
+    else:
+      return self
+
+  def is_compatible_with(self, other):
+    """Returns True iff `self` is compatible with `other`.
+
+    Two possibly-partially-defined shapes are compatible if there
+    exists a fully-defined shape that both shapes can represent. Thus,
+    compatibility allows the shape inference code to reason about
+    partially-defined shapes. For example:
+
+    * TensorShape(None) is compatible with all shapes.
+
+    * TensorShape([None, None]) is compatible with all two-dimensional
+      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
+      not compatible with, for example, TensorShape([None]) or
+      TensorShape([None, None, None]).
+
+    * TensorShape([32, None]) is compatible with all two-dimensional shapes
+      with size 32 in the 0th dimension, and also TensorShape([None, None])
+      and TensorShape(None). It is not compatible with, for example,
+      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
+
+    * TensorShape([32, 784]) is compatible with itself, and also
+      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
+      None]) and TensorShape(None). It is not compatible with, for example,
+      TensorShape([32, 1, 784]) or TensorShape([None]).
+
+    The compatibility relation is reflexive and symmetric, but not
+    transitive. For example, TensorShape([32, 784]) is compatible with
+    TensorShape(None), and TensorShape(None) is compatible with
+    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
+    TensorShape([4, 4]).
+
+    Args:
+      other: Another TensorShape.
+
+    Returns:
+      True iff `self` is compatible with `other`.
+
+    """
+    other = as_shape(other)
+    if self._dims is not None and other.dims is not None:
+      if self.ndims != other.ndims:
+        return False
+      for x_dim, y_dim in zip(self._dims, other.dims):
+        if not x_dim.is_compatible_with(y_dim):
+          return False
+    return True
+
+  def assert_is_compatible_with(self, other):
+    """Raises exception if `self` and `other` do not represent the same shape.
+
+    This method can be used to assert that there exists a shape that both
+    `self` and `other` represent.
+
+    Args:
+      other: Another TensorShape.
+
+    Raises:
+      ValueError: If `self` and `other` do not represent the same shape.
+    """
+    if not self.is_compatible_with(other):
+      raise ValueError("Shapes %s and %s are incompatible" % (self, other))
+
+  def is_fully_defined(self):
+    """Returns True iff `self` is fully defined in every dimension."""
+    return (self._dims is not None
+            and all(dim.value is not None for dim in self._dims))
+
+  def assert_is_fully_defined(self):
+    """Raises an exception if `self` is not fully defined in every dimension.
+
+    Raises:
+      ValueError: If `self` does not have a known value for every dimension.
+    """
+    if not self.is_fully_defined():
+      raise ValueError("Shape %s is not fully defined" % self)
+
+  def as_dimension_list(self):
+    """DEPRECATED: use as_list()."""
+    self.assert_is_fully_defined()
+    return self.as_list()
+
+  def as_list(self):
+    """Returns a list of integers or None for each dimension."""
+    # NOTE(review): fails with a TypeError when the rank is unknown
+    # (self._dims is None) -- confirm callers check ndims first.
+    return [dim.value for dim in self._dims]
+
+  def __eq__(self, other):
+    """Returns True if `self` is equivalent to `other`."""
+    other = as_shape(other)
+    return self._dims == other.dims
+
+  def __ne__(self, other):
+    """Returns True if `self` is known to be different from `other`."""
+    other = as_shape(other)
+    if self.ndims is None or other.ndims is None:
+      raise ValueError("The inequality of unknown TensorShapes is undefined.")
+    if self.ndims != other.ndims:
+      return True
+    return self._dims != other.dims
+
+
+def as_shape(shape):
+  """Converts the given object to a TensorShape.
+
+  A TensorShape is returned unmodified; anything else is passed to the
+  TensorShape constructor (None, an int, or a list of dimension values).
+  """
+  if isinstance(shape, TensorShape):
+    return shape
+  else:
+    return TensorShape(shape)
+
+
+def unknown_shape(ndims=None):
+  """Returns an unknown TensorShape, optionally with a known rank.
+
+  Args:
+    ndims: (Optional) If specified, the number of dimensions in the shape.
+
+  Returns:
+    An unknown TensorShape.
+  """
+  if ndims is None:
+    return TensorShape(None)
+  else:
+    # Known rank, but every dimension size is unknown.
+    return TensorShape([Dimension(None) for _ in range(ndims)])
+
+
+def scalar():
+  """Returns a shape representing a scalar (a rank-0 TensorShape)."""
+  return TensorShape([])
+
+
+def vector(length):
+  """Returns a shape representing a vector.
+
+  Args:
+    length: The length of the vector, which may be None if unknown.
+
+  Returns:
+    A TensorShape representing a vector of the given length.
+  """
+  return TensorShape([length])
+
+
+def matrix(rows, cols):
+  """Returns a shape representing a matrix.
+
+  Args:
+    rows: The number of rows in the matrix, which may be None if unknown.
+    cols: The number of columns in the matrix, which may be None if unknown.
+
+  Returns:
+    A TensorShape representing a matrix of the given size.
+  """
+  return TensorShape([rows, cols])
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
new file mode 100644
index 0000000000..9743a8d199
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -0,0 +1,232 @@
+"""Functional tests for shape inference helper classes."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
class DimensionTest(test_util.TensorFlowTestCase):
  """Unit tests for tensor_shape.Dimension."""

  def testDimension(self):
    """Known dimension: value access, arithmetic, merging and ordering."""
    dim = tensor_shape.Dimension(12)
    self.assertEqual(12, dim.value)
    self.assertEqual(12, int(dim))
    self.assertEqual(dim, tensor_shape.Dimension(12))
    self.assertEqual(tensor_shape.Dimension(15),
                     dim + tensor_shape.Dimension(3))
    self.assertEqual(tensor_shape.Dimension(15), dim + 3)
    self.assertEqual(tensor_shape.Dimension(24),
                     dim * tensor_shape.Dimension(2))
    self.assertEqual(tensor_shape.Dimension(24), dim * 2)
    self.assertEqual(tensor_shape.Dimension(6), dim / tensor_shape.Dimension(2))
    self.assertEqual(tensor_shape.Dimension(6), dim / 2)
    self.assertEqual(tensor_shape.Dimension(12),
                     dim.merge_with(tensor_shape.Dimension(12)))
    self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12))
    self.assertLess(tensor_shape.Dimension(12), tensor_shape.Dimension(13))
    self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12))
    self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(12))
    self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(13))
    # NOTE(review): this is a duplicate of the assertGreater check above.
    self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12))
    self.assertGreaterEqual(tensor_shape.Dimension(12),
                            tensor_shape.Dimension(12))
    self.assertGreaterEqual(tensor_shape.Dimension(13),
                            tensor_shape.Dimension(12))
    # Merging incompatible known dimensions must fail.
    with self.assertRaises(ValueError):
      dim.merge_with(tensor_shape.Dimension(13))

  def testUnknownDimension(self):
    """Unknown dimension: arithmetic yields unknown; comparisons are None."""
    dim = tensor_shape.Dimension(None)
    self.assertIs(None, dim.value)
    self.assertEqual(dim.value, tensor_shape.Dimension(None).value)
    self.assertEqual(tensor_shape.Dimension(None).value,
                     (dim + tensor_shape.Dimension(None)).value)
    self.assertEqual(tensor_shape.Dimension(None).value,
                     (dim * tensor_shape.Dimension(None)).value)
    self.assertEqual(tensor_shape.Dimension(None).value,
                     (dim / tensor_shape.Dimension(None)).value)
    self.assertEqual(tensor_shape.Dimension(None).value,
                     dim.merge_with(tensor_shape.Dimension(None)).value)
    self.assertIs(None,
                  tensor_shape.Dimension(None) < tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) <= tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) > tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) >= tensor_shape.Dimension(None))

  def testKnownAndUnknownDimensions(self):
    """Mixed known/unknown operands: unknown results, except merge_with."""
    known = tensor_shape.Dimension(12)
    unknown = tensor_shape.Dimension(None)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (known + unknown).value)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (unknown + known).value)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (known * unknown).value)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (unknown * known).value)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (known / unknown).value)
    self.assertEqual(
        tensor_shape.Dimension(None).value, (unknown / known).value)
    # merge_with recovers the known value regardless of operand order.
    self.assertEqual(
        tensor_shape.Dimension(12), known.merge_with(unknown))
    self.assertEqual(
        tensor_shape.Dimension(12), unknown.merge_with(known))
    self.assertIs(None,
                  tensor_shape.Dimension(12) < tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(12) <= tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(12) > tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(12) >= tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) < tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(None) <= tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(None) > tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(None) >= tensor_shape.Dimension(12))

  def testAsDimension(self):
    """as_dimension passes Dimension through and wraps ints/None."""
    self.assertEqual(tensor_shape.Dimension(12),
                     tensor_shape.as_dimension(tensor_shape.Dimension(12)))
    self.assertEqual(tensor_shape.Dimension(12), tensor_shape.as_dimension(12))
    self.assertEqual(
        tensor_shape.Dimension(None).value,
        tensor_shape.as_dimension(tensor_shape.Dimension(None)).value)
    self.assertEqual(tensor_shape.Dimension(None).value,
                     tensor_shape.as_dimension(None).value)

  def testEquality(self):
    """== is three-valued: True/False when both known, else None."""
    self.assertTrue(tensor_shape.Dimension(12) == tensor_shape.Dimension(12))
    self.assertFalse(tensor_shape.Dimension(12) == tensor_shape.Dimension(13))
    self.assertIs(None,
                  tensor_shape.Dimension(12) == tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) == tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(None) == tensor_shape.Dimension(None))

  def testInequality(self):
    """!= is three-valued: True/False when both known, else None."""
    self.assertTrue(tensor_shape.Dimension(12) != tensor_shape.Dimension(13))
    self.assertFalse(tensor_shape.Dimension(12) != tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(12) != tensor_shape.Dimension(None))
    self.assertIs(None,
                  tensor_shape.Dimension(None) != tensor_shape.Dimension(12))
    self.assertIs(None,
                  tensor_shape.Dimension(None) != tensor_shape.Dimension(None))
+
+
class ShapeTest(test_util.TensorFlowTestCase):
  """Unit tests for tensor_shape.TensorShape."""

  def testUnknownShape(self):
    """A shape built from None has no rank, no dims, and is falsy."""
    s = tensor_shape.TensorShape(None)
    with self.assertRaises(ValueError):
      s.assert_is_fully_defined()
    self.assertIs(None, s.ndims)
    with self.assertRaises(ValueError):
      len(s)
    self.assertFalse(s)
    self.assertIs(None, s.dims)

  def testFullyDefinedShape(self):
    """All dimensions known: rank, len, indexing, as_list and assertions."""
    s = tensor_shape.TensorShape([tensor_shape.Dimension(3),
                                  tensor_shape.Dimension(4),
                                  tensor_shape.Dimension(7)])
    s.assert_is_fully_defined()
    self.assertEqual(3, s.ndims)
    self.assertEqual(3, len(s))
    self.assertTrue(s)
    s.assert_has_rank(3)
    self.assertEqual([tensor_shape.Dimension(3),
                      tensor_shape.Dimension(4),
                      tensor_shape.Dimension(7)], s.dims)
    self.assertEqual(tensor_shape.Dimension(3), s[0])
    self.assertEqual(tensor_shape.Dimension(4), s[1])
    self.assertEqual(tensor_shape.Dimension(7), s[2])
    self.assertEqual([3, 4, 7], s.as_list())
    s.assert_is_compatible_with([3, 4, 7])
    # assert_same_rank only checks rank, not the dimension values.
    s.assert_same_rank([6, 3, 7])

  def testPartiallyDefinedShape(self):
    """One unknown dimension: rank is still known, full definition is not."""
    s = tensor_shape.TensorShape([tensor_shape.Dimension(3),
                                  tensor_shape.Dimension(None),
                                  tensor_shape.Dimension(7)])
    with self.assertRaises(ValueError):
      s.assert_is_fully_defined()
    self.assertEqual(3, s.ndims)
    self.assertEqual(3, len(s))
    self.assertTrue(s)
    s.assert_has_rank(3)
    self.assertEqual(tensor_shape.Dimension(3), s[0])
    self.assertEqual(tensor_shape.Dimension(None).value, s[1].value)
    self.assertEqual(tensor_shape.Dimension(7), s[2])
    s.assert_same_rank([6, 3, 7])

  def testMergeFullShapes(self):
    """Merging identical full shapes succeeds; conflicting ones raise."""
    self.assertEqual([3, 4, 7],
                     tensor_shape.TensorShape([3, 4, 7]).merge_with(
                         tensor_shape.TensorShape([3, 4, 7])).as_list())
    with self.assertRaises(ValueError):
      tensor_shape.TensorShape([3, 4, 7]).merge_with(
          tensor_shape.TensorShape([6, 3, 7]))

  def testMergePartialShapes(self):
    """Merging fills each unknown dimension from the other shape."""
    s1 = tensor_shape.TensorShape([tensor_shape.Dimension(3),
                                   tensor_shape.Dimension(None),
                                   tensor_shape.Dimension(7)])
    s2 = tensor_shape.TensorShape([tensor_shape.Dimension(None),
                                   tensor_shape.Dimension(4),
                                   tensor_shape.Dimension(7)])
    self.assertEqual([3, 4, 7], s1.merge_with(s2).as_list())

  def testMergeFullAndUnknownShape(self):
    """Merging with a fully unknown shape preserves the known shape."""
    self.assertEqual([3, 4, 7],
                     tensor_shape.TensorShape([3, 4, 7]).merge_with(
                         tensor_shape.TensorShape(None)).as_list())

  def testSlice(self):
    """Indexing and slicing work on both known and unknown shapes."""
    known = tensor_shape.TensorShape([0, 1, 2, 3, 4])
    self.assertEqual(tensor_shape.Dimension(2), known[2])
    tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with(known[1:4])

    unknown = tensor_shape.TensorShape(None)
    self.assertEqual(tensor_shape.Dimension(None).value, unknown[2].value)
    tensor_shape.TensorShape(
        [None, None, None]).assert_is_compatible_with(unknown[1:4])

  def testConcatenate(self):
    """Concatenation with shapes, unknown shapes, and a single Dimension."""
    tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
        tensor_shape.TensorShape([1, 2]).concatenate(
            tensor_shape.TensorShape([3, 4])))
    tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
        tensor_shape.TensorShape([1, 2]).concatenate(
            tensor_shape.TensorShape(None)))
    tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
        tensor_shape.TensorShape(None).concatenate(
            tensor_shape.TensorShape([3, 4])))
    tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with(
        tensor_shape.TensorShape(None).concatenate(
            tensor_shape.TensorShape(None)))
    tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with(
        tensor_shape.TensorShape([1, 2]).concatenate(
            tensor_shape.Dimension(3)))

  def testHelpers(self):
    """The scalar/vector/matrix helpers build the expected shapes."""
    tensor_shape.TensorShape([]).assert_is_compatible_with(
        tensor_shape.scalar())
    tensor_shape.TensorShape([37]).assert_is_compatible_with(
        tensor_shape.vector(37))
    tensor_shape.TensorShape(
        [94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))
+
+
# Standard test runner entry point.
if __name__ == "__main__":
  googletest.main()
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
new file mode 100644
index 0000000000..81ed54c473
--- /dev/null
+++ b/tensorflow/python/framework/tensor_util.py
@@ -0,0 +1,511 @@
+"""Utilities to create TensorProtos."""
+import numbers
+import tensorflow.python.platform
+import numpy as np
+
+from tensorflow.core.framework import tensor_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+
+# TODO(opensource): Add support for pyx_library in the open-source build.
+# For now, we use the slow versions that fast_tensor_util replaces.
+# pylint: disable=g-import-not-at-top
+try:
+ from tensorflow.python.framework import fast_tensor_util
+ _FAST_TENSOR_UTIL_AVAILABLE = True
+except ImportError:
+ _FAST_TENSOR_UTIL_AVAILABLE = False
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+# pylint: enable=g-import-not-at-top
+
+
# Maps numpy scalar types to functions that append a flat sequence of values
# of that type to the matching repeated field of a TensorProto.  The fast
# (compiled) versions are installed when available; otherwise the Slow*
# pure-Python fallbacks defined below are used.
if _FAST_TENSOR_UTIL_AVAILABLE:
  _NP_TO_APPEND_FN = {
      np.float32: fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64: fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32: fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64: fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8: fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.int16: fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.int8: fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.complex64: fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto,
      # np.object / np.bool were plain aliases of the builtins and were
      # removed in NumPy 1.24; the builtins are the same dict keys.
      object: fast_tensor_util.AppendObjectArrayToTensorProto,
      bool: fast_tensor_util.AppendBoolArrayToTensorProto,
      types.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      types.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      types.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
  }
else:
  # Slow fallbacks: convert each numpy scalar to the corresponding Python
  # scalar with ndarray.item() (np.asscalar was just a wrapper around
  # .item() and has been removed from NumPy) and extend the proto's
  # repeated field.

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplexArrayToTensorProto(tensor_proto, proto_values):
    # Complex values are stored as interleaved (real, imag) pairs.
    tensor_proto.scomplex_val.extend([v.item()
                                      for x in proto_values
                                      for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([str(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplexArrayToTensorProto,
      np.complex128: SlowAppendComplexArrayToTensorProto,
      # See note above on the removed np.object / np.bool aliases.
      object: SlowAppendObjectArrayToTensorProto,
      bool: SlowAppendBoolArrayToTensorProto,
      types.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
      types.quint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
      types.qint32.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
      # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
  }
+
+
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in a dict keyed by numpy scalar types.

  A plain dtype_dict.get(dtype) does not work here: the np.dtype object
  passed in does not hash equal to the scalar-type keys, so we compare
  each key with `==` instead.

  Args:
    dtype_dict: Dict mapping numpy scalar types (or dtypes) to values.
    dtype: The numpy dtype to look up.

  Returns:
    The value whose key compares equal to `dtype`, or None if none does.
  """
  # items() instead of the Python-2-only iteritems(): identical iteration
  # semantics, and compatible with Python 3.
  for key, val in dtype_dict.items():
    if key == dtype:
      return val
  return None
+
+
def GetNumpyAppendFn(dtype):
  """Returns the proto-append function for `dtype`, or None if unsupported.

  String dtypes are special-cased: numpy string dtypes are variable length,
  so there is no single scalar-type constant to look up in the table
  (np.string does not exist); we inspect dtype.type instead.
  """
  if dtype.type in (np.string_, np.unicode_):
    if _FAST_TENSOR_UTIL_AVAILABLE:
      return fast_tensor_util.AppendObjectArrayToTensorProto
    return SlowAppendObjectArrayToTensorProto
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
+
+
def MakeTensorShapeProto(shape):
  """Create a TensorShapeProto.

  Args:
    shape: List of integers representing the dimensions of the tensor.

  Returns:
    A TensorShapeProto with one Dim entry per element of `shape`.
  """
  proto_dims = [tensor_shape_pb2.TensorShapeProto.Dim(size=size)
                for size in shape]
  return tensor_shape_pb2.TensorShapeProto(dim=proto_dims)
+
+
def TensorShapeProtoToList(shape):
  """Convert a TensorShapeProto to a list of dimension sizes.

  Args:
    shape: A TensorShapeProto.

  Returns:
    List of integers, one per dimension of the tensor.
  """
  sizes = []
  for dim in shape.dim:
    sizes.append(dim.size)
  return sizes
+
+
+def _GetDenseDimensions(list_of_lists):
+ """Returns the inferred dense dimensions of a list of lists."""
+ if not isinstance(list_of_lists, (list, tuple)):
+ return []
+ elif not list_of_lists:
+ return [0]
+ else:
+ return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
+
+
+def _FlattenToStrings(nested_strings):
+ if isinstance(nested_strings, list):
+ for inner in nested_strings:
+ for flattened_string in _FlattenToStrings(inner):
+ yield flattened_string
+ else:
+ yield nested_strings
+
+
# Dtypes whose values may be serialized compactly into the raw
# TensorProto.tensor_content bytes field instead of a typed repeated field
# (see the fast path in make_tensor_proto).
_TENSOR_CONTENT_TYPES = frozenset([
    types.float32, types.float64, types.int32, types.uint8, types.int16,
    types.int8, types.int64
])
+
+
+def _FirstNotNone(l):
+ for x in l:
+ if x is not None:
+ return x
+ return None
+
+
+def _FilterInt(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterInt(x) for x in v])
+ return None if isinstance(v, numbers.Integral) else repr(v)
+
+
+def _FilterFloat(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterFloat(x) for x in v])
+ return None if isinstance(v, numbers.Real) else repr(v)
+
+
+def _FilterComplex(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterComplex(x) for x in v])
+ return None if isinstance(v, numbers.Complex) else repr(v)
+
+
def _FilterStr(v):
  """Returns None if `v` contains only strings (recursively), else the repr
  of the first non-string value."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterStr(x) for x in v])
  # NOTE(review): `basestring` is Python 2 only (covers str and unicode);
  # this raises NameError under Python 3 — confirm before porting.
  if isinstance(v, basestring):
    return None
  return repr(v)
+
+
+def _FilterBool(v):
+ if isinstance(v, (list, tuple)):
+ return _FirstNotNone([_FilterBool(x) for x in v])
+ return None if isinstance(v, bool) else repr(v)
+
+
def _FilterNotTensor(v):
  """Returns the repr of the first ops.Tensor found in `v` (recursively) —
  Tensors are not allowed here — or None if there is no Tensor."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterNotTensor(x) for x in v])
  if isinstance(v, ops.Tensor):
    return repr(v)
  return None
+
+
# Maps each TensorFlow dtype to the value-validation filter used by
# _AssertCompatible. Dtypes not listed fall back to _FilterNotTensor.
_TF_TO_IS_OK = {
    types.float32: _FilterFloat,
    types.float64: _FilterFloat,
    types.int32: _FilterInt,
    types.uint8: _FilterInt,
    types.int16: _FilterInt,
    types.int8: _FilterInt,
    types.string: _FilterStr,
    types.complex64: _FilterComplex,
    types.int64: _FilterInt,
    types.bool: _FilterBool,
    types.qint32: _FilterInt,
    types.quint8: _FilterInt,
    types.qint8: _FilterInt,
}
+
+
def _AssertCompatible(values, dtype):
  """Raises TypeError if `values` are not compatible with `dtype`.

  A dtype without a registered filter (including None) only rejects
  ops.Tensor values.
  """
  mismatch = _TF_TO_IS_OK.get(dtype, _FilterNotTensor)(values)
  if mismatch is None:
    return
  if dtype is None:
    raise TypeError("List of Tensors when single Tensor expected")
  raise TypeError("Expected %s, got %s instead." % (dtype.name, mismatch))
+
+
def make_tensor_proto(values, dtype=None, shape=None):
  """Create a TensorProto.

  Args:
    values: Values to put in the TensorProto.
    dtype: Optional tensor_pb2 DataType value.
    shape: List of integers representing the dimensions of tensor.

  Returns:
    A TensorProto. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with tensor_util.MakeNdarray(proto).

  Raises:
    TypeError: if unsupported types are provided.
    ValueError: if arguments have inappropriate values.

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first convert it to numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have the compatible type with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  """
  if dtype:
    dtype = types.as_dtype(dtype)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    np_dt = dtype.as_numpy_dtype if dtype else None
    # NOTE(review): when shape is None this evaluates np.prod(None) —
    # presumably callers with empty values always pass a shape; confirm.
    if np.prod(shape) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # Reject ragged nested lists: numpy's inferred shape would not match
      # the nesting structure of the input.
      if list(nparray.shape) != _GetDenseDimensions(values):
        raise ValueError("Argument must be a dense tensor: %s" % values)
    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      nparray = nparray.astype(np.int32)

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = types.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if dtype in [types.qint8, types.quint8, types.qint32]:
    numpy_dtype = dtype

  if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
    raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape)
    is_same_size = shape_size == nparray.size

  # Fewer elements than the shape requires is allowed (the values act as a
  # partial fill); more is not.
  if nparray.size > shape_size:
    raise ValueError(
        "Too many elements provided. Needed at most %d, but received %d" %
        (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=MakeTensorShapeProto(shape))

  # Fast path: serialize raw bytes into tensor_content when the dtype
  # supports it and the data exactly fills the requested shape.
  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    tensor_proto.tensor_content = nparray.tostring()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == types.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)
    tensor_proto.string_val.extend([str(x) for x in proto_values])
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError("Element type not supported in TensorProto: %s" %
                    numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
+
+
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape)
  tensor_dtype = types.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    # Raw bytes produced by make_tensor_proto's tensor_content fast path.
    # NOTE(review): np.fromstring is deprecated in modern NumPy in favor of
    # np.frombuffer — confirm before upgrading NumPy.
    return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape)
  elif tensor_dtype == types.float32:
    # For each typed field below, a single stored value is treated as a
    # splat that fills the whole shape; otherwise values map one-to-one.
    if len(tensor.float_val) == 1:
      return np.repeat(np.array(tensor.float_val[0], dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
  elif tensor_dtype == types.float64:
    if len(tensor.double_val) == 1:
      return np.repeat(np.array(tensor.double_val[0], dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
  elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8,
                        types.qint32, types.quint8, types.qint8,
                        types.bfloat16]:
    # All small integer (and quantized) types share the int_val field.
    if len(tensor.int_val) == 1:
      return np.repeat(np.array(tensor.int_val[0], dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
  elif tensor_dtype == types.int64:
    if len(tensor.int64_val) == 1:
      return np.repeat(np.array(tensor.int64_val[0], dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
  elif tensor_dtype == types.string:
    if len(tensor.string_val) == 1:
      return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.array([str(x) for x in tensor.string_val],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.complex64:
    # scomplex_val stores interleaved (real, imag) pairs; zip(it, it)
    # consumes the same iterator twice to re-pair them.
    it = iter(tensor.scomplex_val)
    if len(tensor.scomplex_val) == 2:
      return np.repeat(np.array(complex(tensor.scomplex_val[0],
                                        tensor.scomplex_val[1]), dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.array([complex(x[0], x[1]) for x in zip(it, it)],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.bool:
    if len(tensor.bool_val) == 1:
      return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
+
+
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape", otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("shape is not a list or tuple")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # zip() truncates to the shorter sequence, so an explicit rank check is
  # required; without it a shape that is a prefix of the tensor's shape
  # (or vice versa) would incorrectly compare equal.
  return (len(tensor_shape_list) == len(shape) and
          all(x == y for x, y in zip(tensor_shape_list, shape)))
+
+
def ConstantValue(tensor):
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  TODO(mrry): Consider whether this function should use a registration
  mechanism like gradients and ShapeFunctions, so that it is easily
  extensible.

  Args:
    tensor: The Tensor to be evaluated.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated.

  Raises:
    TypeError: if tensor is not an ops.Tensor.
  """
  # TODO(mdevin): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    # The value is stored directly in the node's "value" attr.
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    # Shape is constant whenever the input's static shape is fully known.
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array([dim.value for dim in input_shape.dims])
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array([np.prod([dim.value for dim in input_shape.dims])])
    else:
      return None
  elif tensor.op.type == "Rank":
    # Rank only needs the number of dimensions, not their sizes.
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      return np.array([input_shape.ndims])
    else:
      return None
  elif tensor.op.type == "Range":
    # All three inputs must themselves be constant for Range to fold.
    start = ConstantValue(tensor.op.inputs[0])
    if start is None:
      return None
    limit = ConstantValue(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = ConstantValue(tensor.op.inputs[2])
    if delta is None:
      return None
    # NOTE(review): start/limit/delta are numpy ndarrays here; passing them
    # to the builtin range() relies on implicit integer conversion — confirm
    # this behaves for the dtypes involved (np.arange would be more direct).
    return np.array(range(start, limit, delta),
                    dtype=tensor.dtype.as_numpy_dtype)
  else:
    return None
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
new file mode 100644
index 0000000000..7c1c0b8d3e
--- /dev/null
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -0,0 +1,379 @@
+"""Functional tests for tensor_util."""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import googletest
+
+
class TensorUtilTest(test_util.TensorFlowTestCase):
  """Tests the tensor proto conversion helpers in tensor_util.

  Each test builds a TensorProto with make_tensor_proto, checks its exact
  textual form with assertProtoEquals, and (where enough values are present)
  round-trips it back through MakeNdarray, checking dtype and contents.
  """

  def testFloat(self):
    t = tensor_util.make_tensor_proto(10.0)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape {}
      float_val: 10.0
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array(10.0, dtype=np.float32), a)

  def testFloatN(self):
    # NOTE(review): tensor_content appears to hold the raw bytes of the
    # packed float32 values — the octal escapes below must stay byte-exact.
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTyped(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerce(self):
    # Integer python values are coerced to the requested float dtype.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerceNdarray(self):
    # Same coercion, but starting from an int ndarray rather than a list.
    arr = np.asarray([10, 20, 30], dtype="int")
    t = tensor_util.make_tensor_proto(arr, dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatSizes(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)

  def testFloatSizes2(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } dim { size: 1 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32),
                        a)

  def testFloatSizesLessValues(self):
    t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      float_val: 10.0
      """, t)
    # No conversion to Ndarray for this one: not enough values.

  def testFloatNpArrayFloat64(self):
    # float64 ndarrays keep their precision and map to DT_DOUBLE.
    t = tensor_util.make_tensor_proto(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
    self.assertProtoEquals("""
      dtype: DT_DOUBLE
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float64, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
                        tensor_util.MakeNdarray(t))

  def testFloatTypesWithImplicitRepeat(self):
    # A single value with a larger shape is broadcast to fill the tensor.
    for dtype, nptype in [
        (types.float32, np.float32), (types.float64, np.float64)]:
      t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0],
                                    [10.0, 10.0, 10.0, 10.0],
                                    [10.0, 10.0, 10.0, 10.0]], dtype=nptype), a)

  def testInt(self):
    t = tensor_util.make_tensor_proto(10)
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape {}
      int_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int32, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int32), a)

  def testIntNDefaultType(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int32, a.dtype)
    self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)

  def testIntTypes(self):
    for dtype, nptype in [
        (types.int32, np.int32),
        (types.uint8, np.uint8),
        (types.int16, np.int16),
        (types.int8, np.int8)]:
      # Test with array.
      t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
      self.assertEquals(dtype, t.dtype)
      self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
      a = tensor_util.MakeNdarray(t)
      self.assertEquals(nptype, a.dtype)
      self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
      # Test with ndarray.
      t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
      self.assertEquals(dtype, t.dtype)
      self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
      a = tensor_util.MakeNdarray(t)
      self.assertEquals(nptype, a.dtype)
      self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  def testIntTypesWithImplicitRepeat(self):
    for dtype, nptype in [
        (types.int64, np.int64),
        (types.int32, np.int32),
        (types.uint8, np.uint8),
        (types.int16, np.int16),
        (types.int8, np.int8)]:
      t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllEqual(np.array([[10, 10, 10, 10],
                                    [10, 10, 10, 10],
                                    [10, 10, 10, 10]], dtype=nptype), a)

  def testLong(self):
    t = tensor_util.make_tensor_proto(10, dtype=types.int64)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int64), a)

  def testLongN(self):
    t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3],
                                      dtype=types.int64)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array([[10, 20, 30]], dtype=np.int64), a)

  def testLongNpArray(self):
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape { dim { size: 3 } }
      tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)

  def testString(self):
    t = tensor_util.make_tensor_proto("foo")
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape {}
      string_val: "foo"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertEquals(["foo"], a)

  def testStringWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto("f", shape=[3, 4])
    a = tensor_util.MakeNdarray(t)
    self.assertAllEqual(np.array([["f", "f", "f", "f"],
                                  ["f", "f", "f", "f"],
                                  ["f", "f", "f", "f"]], dtype=np.object), a)

  def testStringN(self):
    t = tensor_util.make_tensor_proto(["foo", "bar", "baz"], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertAllEqual(np.array([["foo", "bar", "baz"]]), a)

  def testStringNpArray(self):
    t = tensor_util.make_tensor_proto(np.array([["a", "ab"], ["abc", "abcd"]]))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertAllEqual(np.array([["a", "ab"], ["abc", "abcd"]]), a)

  def testComplex(self):
    t = tensor_util.make_tensor_proto((1+2j), dtype=types.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape {}
      scomplex_val: 1
      scomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplexWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4],
                                      dtype=types.complex64)
    a = tensor_util.MakeNdarray(t)
    self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)],
                                  [(1+1j), (1+1j), (1+1j), (1+1j)],
                                  [(1+1j), (1+1j), (1+1j), (1+1j)]],
                                 dtype=np.complex64), a)

  def testComplexN(self):
    t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3],
                                      dtype=types.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1+2j), (3+4j), (5+6j)]]), a)

  def testComplexNpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=types.complex64)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      scomplex_val: 7
      scomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), a)

  def testUnsupportedDType(self):
    # 0 is not a valid dtype argument.
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(np.array([1]), 0)

  def testShapeTooLarge(self):
    # More values than the requested shape can hold.
    with self.assertRaises(ValueError):
      tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])

  def testLowRankSupported(self):
    # Scalars (rank-0 ndarrays) are supported.
    t = tensor_util.make_tensor_proto(np.array(7))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 7
      """, t)

  def testShapeEquals(self):
    # ShapeEquals accepts lists, tuples, and TensorShapeProtos.
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
    self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
    self.assertTrue(
        tensor_util.ShapeEquals(t, tensor_util.MakeTensorShapeProto([2, 2])))
    self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
    self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
    self.assertFalse(tensor_util.ShapeEquals(t, [4]))
+
class ConstantValueTest(test_util.TensorFlowTestCase):
  """Tests for tensor_util.ConstantValue constant folding."""

  def testConstant(self):
    # Const ops fold to their stored value, including zero-size tensors.
    for dims in [(3, 4, 7), (3, 0, 7)]:
      values = np.random.rand(*dims).astype(np.float32)
      const_tensor = constant_op.constant(values)
      self.assertAllClose(values, tensor_util.ConstantValue(const_tensor))

  def testUnknown(self):
    # A variable's value is unknown at graph-construction time.
    var_tensor = state_ops.variable_op(shape=[3, 4, 7], dtype=types.float32)
    self.assertIs(None, tensor_util.ConstantValue(var_tensor))

  def testShape(self):
    shape_tensor = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np.array([1, 2, 3]),
                        tensor_util.ConstantValue(shape_tensor))

  def testSize(self):
    size_tensor = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np.array([6]),
                        tensor_util.ConstantValue(size_tensor))

  def testRank(self):
    rank_tensor = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np.array([3]),
                        tensor_util.ConstantValue(rank_tensor))
+
+
# Test-binary entry point: run all test cases in this module.
if __name__ == "__main__":
  googletest.main()
diff --git a/tensorflow/python/framework/test_kernel_label_op.cc b/tensorflow/python/framework/test_kernel_label_op.cc
new file mode 100644
index 0000000000..50f8522e1b
--- /dev/null
+++ b/tensorflow/python/framework/test_kernel_label_op.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/status.h"
+
namespace tensorflow {

// Test-only op producing a single string scalar that identifies which of the
// registered kernels (selected by registration label) actually ran.
REGISTER_OP("KernelLabel").Output("result: string");

namespace {
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
}  // namespace

// Kernel for "KernelLabel".  The template parameter fixes, at compile time,
// the message written to the output, so a test reading the result can tell
// which registration was chosen.
template <KernelLabel KL>
class KernelLabelOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("result", TensorShape({}), &output));
    switch (KL) {
      case DEFAULT_LABEL:
        output->scalar<string>()() = "My label is: default";
        break;
      case OVERLOAD_1_LABEL:
        output->scalar<string>()() = "My label is: overload_1";
        break;
      case OVERLOAD_2_LABEL:
        output->scalar<string>()() = "My label is: overload_2";
        break;
    }
  }
};

// Three CPU registrations of the same op, distinguished only by the
// registration label (no label, "overload_1", "overload_2").
REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU),
                        KernelLabelOp<DEFAULT_LABEL>);
REGISTER_KERNEL_BUILDER(Name("KernelLabel")
                            .Device(DEVICE_CPU)
                            .Label("overload_1"),
                        KernelLabelOp<OVERLOAD_1_LABEL>);
REGISTER_KERNEL_BUILDER(Name("KernelLabel")
                            .Device(DEVICE_CPU)
                            .Label("overload_2"),
                        KernelLabelOp<OVERLOAD_2_LABEL>);

}  // end namespace tensorflow
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
new file mode 100644
index 0000000000..597a5ad829
--- /dev/null
+++ b/tensorflow/python/framework/test_util.py
@@ -0,0 +1,437 @@
+# pylint: disable=invalid-name
+"""Test utils for tensorflow."""
+import contextlib
+import math
+import re
+import threading
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import config_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.client import graph_util
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+from tensorflow.python.util.protobuf import compare
+
+
def IsGoogleCudaEnabled():
  """Returns the value reported by the native IsGoogleCudaEnabled() wrapper.

  Thin pass-through to the SWIG-generated pywrap_tensorflow module;
  presumably True when the build has CUDA support — confirm against the
  C++ definition.
  """
  return pywrap_tensorflow.IsGoogleCudaEnabled()
+
+
+class TensorFlowTestCase(googletest.TestCase):
+ """Root class for tests that need to test tensor flow.
+ """
+
+ def __init__(self, methodName="runTest"):
+ super(TensorFlowTestCase, self).__init__(methodName)
+ self._threads = []
+ self._tempdir = None
+ self._cached_session = None
+
+ def setUp(self):
+ self._ClearCachedSession()
+ ops.reset_default_graph()
+
+ def tearDown(self):
+ for thread in self._threads:
+ self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
+ self._ClearCachedSession()
+
+ def _ClearCachedSession(self):
+ if self._cached_session is not None:
+ self._cached_session.close()
+ self._cached_session = None
+
+ def get_temp_dir(self):
+ if not self._tempdir:
+ self._tempdir = googletest.GetTempDir()
+ return self._tempdir
+
+ def _AssertProtoEquals(self, a, b):
+ """Asserts that a and b are the same proto.
+
+ Uses Proto2Cmp() first, as it returns correct results
+ for floating point attributes, and then use assertProto2Equal()
+ in case of failure as it provides good error messages.
+
+ Args:
+ a: a proto.
+ b: another proto.
+ """
+ if compare.Proto2Cmp(a, b) != 0:
+ compare.assertProto2Equal(self, a, b, normalize_numbers=True)
+
+ def assertProtoEquals(self, expected_message_maybe_ascii, message):
+ """Asserts that message is same as parsed expected_message_ascii.
+
+ Creates another prototype of message, reads the ascii message into it and
+ then compares them using self._AssertProtoEqual().
+
+ Args:
+ expected_message_maybe_ascii: proto message in original or ascii form
+ message: the message to validate
+ """
+
+ if type(expected_message_maybe_ascii) == type(message):
+ expected_message = expected_message_maybe_ascii
+ self._AssertProtoEquals(expected_message, message)
+ elif isinstance(expected_message_maybe_ascii, str):
+ expected_message = type(message)()
+ text_format.Merge(expected_message_maybe_ascii, expected_message)
+ self._AssertProtoEquals(expected_message, message)
+ else:
+ assert False, ("Can't compare protos of type " +
+ type(expected_message_maybe_ascii) + " and " +
+ type(message))
+
+ def assertStartsWith(self, actual, expected_start, msg=None):
+ """Assert that actual.startswith(expected_start) is True.
+
+ Args:
+ actual: str
+ expected_start: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.startswith(expected_start):
+ fail_msg = "%r does not start with %r" % (actual, expected_start)
+ fail_msg += " : %r" % (msg) if msg else ""
+ self.fail(fail_msg)
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def test_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Returns a TensorFlow Session for use in executing tests.
+
+ This method should be used for all functional tests.
+
+ Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+ `force_gpu` is True, all ops are pinned to `/gpu:0`. Otherwise, if `use_gpu`
+ is True, TensorFlow tries to run as many ops on the GPU as possible. If both
+ `force_gpu and `use_gpu` are False, all ops are pinned to the CPU.
+
+ Example:
+
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ with self.test_session(use_gpu=True):
+ valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ result = MyOperator(valid_input).eval()
+ self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
+ invalid_input = [-1.0, 2.0, 7.0]
+ with self.assertRaisesOpError("negative input not supported"):
+ MyOperator(invalid_input).eval()
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ use_gpu: If True, attempt to run as many ops as possible on GPU.
+ force_gpu: If True, pin all ops to `/gpu:0`.
+
+ Returns:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ def prepare_config(config):
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ config = config_pb2.ConfigProto().CopyFrom(config)
+ config.allow_soft_placement = False
+ return config
+
+ if graph is None:
+ if self._cached_session is None:
+ self._cached_session = session.Session(graph=None,
+ config=prepare_config(config))
+ sess = self._cached_session
+ with sess.graph.as_default(), sess.as_default():
+ if force_gpu:
+ with sess.graph.device("/gpu:0"):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device(graph_util.pin_to_cpu):
+ yield sess
+ else:
+ with session.Session(graph=graph, config=prepare_config(config)) as sess:
+ if force_gpu:
+ with sess.graph.device("/gpu:0"):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device(graph_util.pin_to_cpu):
+ yield sess
+ # pylint: enable=g-doc-return-or-yield
+
+ class _CheckedThread(object):
+ """A wrapper class for Thread that asserts successful completion.
+
+ This class should be created using the TensorFlowTestCase.checkedThread()
+ method.
+ """
+
+ def __init__(self, testcase, target, args=None, kwargs=None):
+ """Constructs a new instance of _CheckedThread.
+
+ Args:
+ testcase: The TensorFlowTestCase for which this thread is being created.
+ target: A callable object representing the code to be executed in the
+ thread.
+ args: A tuple of positional arguments that will be passed to target.
+ kwargs: A dictionary of keyword arguments that will be passed to target.
+ """
+ self._testcase = testcase
+ self._target = target
+ self._args = () if args is None else args
+ self._kwargs = {} if kwargs is None else kwargs
+ self._thread = threading.Thread(target=self._protected_run)
+ self._exception = None
+
+ def _protected_run(self):
+ """Target for the wrapper thread. Sets self._exception on failure."""
+ try:
+ self._target(*self._args, **self._kwargs)
+# pylint: disable=broad-except
+ except Exception as e:
+ # pylint: enable=broad-except
+ self._exception = e
+
+ def start(self):
+ """Starts the thread's activity.
+
+ This must be called at most once per _CheckedThread object. It arranges
+ for the object's target to be invoked in a separate thread of control.
+ """
+ self._thread.start()
+
+ def join(self):
+ """Blocks until the thread terminates.
+
+ Raises:
+ self._testcase.failureException: If the thread terminates with due to
+ an exception.
+ """
+ self._thread.join()
+ if self._exception is not None:
+ self._testcase.fail(
+ "Error in checkedThread: %s" % str(self._exception))
+
+ def is_alive(self):
+ """Returns whether the thread is alive.
+
+ This method returns True just before the run() method starts
+ until just after the run() method terminates.
+
+ Returns:
+ True if the thread is alive, otherwise False.
+ """
+ return self._thread.is_alive()
+
+ def checkedThread(self, target, args=None, kwargs=None):
+ """Returns a Thread wrapper that asserts 'target' completes successfully.
+
+ This method should be used to create all threads in test cases, as
+ otherwise there is a risk that a thread will silently fail, and/or
+ assertions made in the thread will not be respected.
+
+ Args:
+ target: A callable object to be executed in the thread.
+ args: The argument tuple for the target invocation. Defaults to ().
+ kwargs: A dictionary of keyword arguments for the target invocation.
+ Defaults to {}.
+
+ Returns:
+ A wrapper for threading.Thread that supports start() and join() methods.
+ """
+ ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
+ self._threads.append(ret)
+ return ret
+# pylint: enable=invalid-name
+
+ def assertNear(self, f1, f2, err):
+ """Asserts that two floats are near each other.
+
+ Checks that |f1 - f2| < err and asserts a test failure
+ if not.
+
+ Args:
+ f1: a float value.
+ f2: a float value.
+ err: a float value.
+ """
+ self.assertTrue(math.fabs(f1 - f2) < err)
+
+ def assertArrayNear(self, farray1, farray2, err):
+ """Asserts that two float arrays are near each other.
+
+ Checks that for all elements of farray1 and farray2
+ |f1 - f2| < err. Asserts a test failure if not.
+
+ Args:
+ farray1: a list of float values.
+ farray2: a list of float values.
+ err: a float value.
+ """
+ for f1, f2 in zip(farray1, farray2):
+ self.assertNear(f1, f2, err)
+
+ def _NDArrayNear(self, ndarray1, ndarray2, err):
+ return np.linalg.norm(ndarray1 - ndarray2) < err
+
+ def assertNDArrayNear(self, ndarray1, ndarray2, err):
+ """Asserts that two numpy arrays have near values.
+
+ Args:
+ ndarray1: a numpy ndarray.
+ ndarray2: a numpy ndarray.
+ err: a float. The maximum absolute difference allowed.
+ """
+ self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err))
+
+ def _GetNdArray(self, a):
+ if not isinstance(a, np.ndarray):
+ a = np.array(a)
+ return a
+
+ def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
+ """Asserts that two numpy arrays have near values.
+
+ Args:
+ a: a numpy ndarray or anything can be converted to one.
+ b: a numpy ndarray or anything can be converted to one.
+ rtol: relative tolerance
+ atol: absolute tolerance
+ """
+ a = self._GetNdArray(a)
+ b = self._GetNdArray(b)
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ if not np.allclose(a, b, rtol=rtol, atol=atol):
+ # Prints more details than np.testing.assert_allclose.
+ #
+ # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
+ # checks whether two arrays are element-wise equal within a
+ # tolerance. The relative difference (rtol * abs(b)) and the
+ # absolute difference atol are added together to compare against
+ # the absolute difference between a and b. Here, we want to
+ # print out which elements violate such conditions.
+ cond = np.abs(a - b) > atol + rtol * np.abs(b)
+ if a.ndim:
+ x = a[np.where(cond)]
+ y = b[np.where(cond)]
+ print "not close where = ", np.where(cond)
+ else:
+ # np.where is broken for scalars
+ x, y = a, b
+ print "not close lhs = ", x
+ print "not close rhs = ", y
+ print "not close dif = ", np.abs(x - y)
+ print "not close tol = ", atol + rtol * np.abs(y)
+ np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
+
+ def assertAllEqual(self, a, b):
+ """Asserts that two numpy arrays have the same values.
+
+ Args:
+ a: a numpy ndarray or anything can be converted to one.
+ b: a numpy ndarray or anything can be converted to one.
+ """
+ a = self._GetNdArray(a)
+ b = self._GetNdArray(b)
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+ same = (a == b)
+
+ if a.dtype == np.float32 or a.dtype == np.float64:
+ same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
+ if not np.all(same):
+ # Prints more details than np.testing.assert_array_equal.
+ diff = np.logical_not(same)
+ if a.ndim:
+ x = a[np.where(diff)]
+ y = b[np.where(diff)]
+ print "not equal where = ", np.where(diff)
+ else:
+ # np.where is broken for scalars
+ x, y = a, b
+ print "not equal lhs = ", x
+ print "not equal rhs = ", y
+ np.testing.assert_array_equal(a, b)
+
+ # pylint: disable=g-doc-return-or-yield
+ @contextlib.contextmanager
+ def assertRaisesWithPredicateMatch(self, exception_type,
+ expected_err_re_or_predicate):
+ """Returns a context manager to enclose code expected to raise an exception.
+
+ Args:
+ exception_type: The expected type of exception that should be raised.
+ expected_err_re_or_predicate: If this is callable, it should be a function
+ of one argument that inspects the passed-in OpError exception and
+ returns True (success) or False (please fail the test). Otherwise, the
+ error message is expected to match this regular expression partially.
+
+ Returns:
+ A context manager to surround code that is expected to raise an
+ errors.OpError exception.
+ """
+ if callable(expected_err_re_or_predicate):
+ predicate = expected_err_re_or_predicate
+ else:
+ def predicate(e):
+ err_str = e.message
+ op = e.op
+ while op is not None:
+ err_str += "\nCaused by: " + op.name
+ op = op._original_op
+ logging.info("Searching within error strings: '%s' within '%s'",
+ expected_err_re_or_predicate, err_str)
+ return re.search(expected_err_re_or_predicate, err_str)
+ try:
+ yield
+ self.fail(exception_type.__name__ + " not raised")
+# pylint: disable=broad-except
+ except Exception as e:
+ # pylint: enable=broad-except
+ if not isinstance(e, exception_type) or not predicate(e):
+ raise AssertionError(e)
+ # pylint: enable=g-doc-return-or-yield
+
+ def assertRaisesOpError(self, expected_err_re_or_predicate):
+ return self.assertRaisesWithPredicateMatch(errors.OpError,
+ expected_err_re_or_predicate)
+
+ def assertShapeEqual(self, np_array, tf_tensor):
+ """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
+
+ Args:
+ np_array: A Numpy ndarray or Numpy scalar.
+ tf_tensor: A Tensor.
+
+ Raises:
+ TypeError: If the arguments have the wrong type.
+ """
+ if not isinstance(np_array, (np.ndarray, np.generic)):
+ raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
+ if not isinstance(tf_tensor, ops.Tensor):
+ raise TypeError("tf_tensor must be a Tensor")
+ self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list())
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
new file mode 100644
index 0000000000..e0618cfea4
--- /dev/null
+++ b/tensorflow/python/framework/test_util_test.py
@@ -0,0 +1,128 @@
+"""Tests for tensorflow.ops.test_util."""
+import threading
+
+import tensorflow.python.platform
+import numpy as np
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.platform import googletest
+from tensorflow.python.ops import logging_ops
+
class TestUtilTest(test_util.TensorFlowTestCase):
  """Tests for the TensorFlowTestCase helpers defined in test_util."""

  def testIsGoogleCudaEnabled(self):
    # The test doesn't assert anything. It ensures the py wrapper
    # function is generated correctly.
    if test_util.IsGoogleCudaEnabled():
      print "GoogleCuda is enabled"
    else:
      print "GoogleCuda is disabled"

  def testAssertProtoEqualsStr(self):

    graph_str = "node { name: 'w1' op: 'params' }"
    graph_def = graph_pb2.GraphDef()
    text_format.Merge(graph_str, graph_def)

    # test string based comparison
    self.assertProtoEquals(graph_str, graph_def)

    # test original comparison
    self.assertProtoEquals(graph_def, graph_def)

  def testNDArrayNear(self):
    a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
    self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
    self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))

  def testCheckedThreadSucceeds(self):
    # A thread that completes normally must not trip join().
    def noop(ev):
      ev.set()

    event_arg = threading.Event()

    self.assertFalse(event_arg.is_set())
    t = self.checkedThread(target=noop, args=(event_arg,))
    t.start()
    t.join()
    self.assertTrue(event_arg.is_set())

  def testCheckedThreadFails(self):
    # An exception raised inside the thread must surface at join().
    def err_func():
      return 1 / 0

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    # NOTE(review): this is the Python 2 ZeroDivisionError message.
    self.assertTrue("integer division or modulo by zero"
                    in fe.exception.message)

  def testCheckedThreadWithWrongAssertionFails(self):
    # Assertion failures inside the thread must also surface at join().
    x = 37

    def err_func():
      self.assertTrue(x < 10)

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    self.assertTrue("False is not true" in fe.exception.message)

  def testMultipleThreadsWithOneFailure(self):
    # Only the failing thread's join() should raise.
    def err_func(i):
      self.assertTrue(i != 7)

    threads = [self.checkedThread(target=err_func, args=(i,))
               for i in range(10)]
    for t in threads:
      t.start()
    for i, t in enumerate(threads):
      if i == 7:
        with self.assertRaises(self.failureException):
          t.join()
      else:
        t.join()

  def _WeMustGoDeeper(self, msg):
    # Raises an OpError whose op chain (op -> original op) should be
    # searched by assertRaisesOpError.
    with self.assertRaisesOpError(msg):
      node_def = ops._NodeDef("op_type", "name")
      node_def_orig = ops._NodeDef("op_type_orig", "orig")
      op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
      op = ops.Operation(node_def, ops.get_default_graph(), original_op=op_orig)
      raise errors.UnauthenticatedError(node_def, op, "true_err")

  def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
    with self.assertRaises(AssertionError):
      self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")

    # The message, the op name, and the original op name must all match.
    self._WeMustGoDeeper("true_err")
    self._WeMustGoDeeper("name")
    self._WeMustGoDeeper("orig")

  def testAllCloseScalars(self):
    self.assertAllClose(7, 7 + 1e-8)
    with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(7, 8)

  def testForceGPU(self):
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "Cannot assign a device to node"):
      with self.test_session(force_gpu=True):
        # this relies on us not having a GPU implementation for assert, which
        # seems sensible
        x = [True]
        y = [15]
        logging_ops.Assert(x, y).run()
+
# Test-binary entry point: run all test cases in this module.
if __name__ == "__main__":
  googletest.main()
diff --git a/tensorflow/python/framework/types.py b/tensorflow/python/framework/types.py
new file mode 100644
index 0000000000..6a8c629fe4
--- /dev/null
+++ b/tensorflow/python/framework/types.py
@@ -0,0 +1,418 @@
+"""Library of dtypes (Tensor element types)."""
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.core.framework import types_pb2
+
+
+class DType(object):
+ """Represents the type of the elements in a `Tensor`.
+
+ The following `DType` objects are defined:
+
+ * `tf.float32`: 32-bit single-precision floating-point.
+ * `tf.float64`: 64-bit double-precision floating-point.
+ * `tf.bfloat16`: 16-bit truncated floating-point.
+ * `tf.complex64`: 64-bit single-precision complex.
+
+ * `tf.int8`: 8-bit signed integer.
+ * `tf.uint8`: 8-bit unsigned integer.
+ * `tf.int32`: 32-bit signed integer.
+ * `tf.int64`: 64-bit signed integer.
+
+ * `tf.bool`: Boolean.
+
+ * `tf.string`: String.
+
+ * `tf.qint8`: Quantized 8-bit signed integer.
+ * `tf.quint8`: Quantized 8-bit unsigned integer.
+ * `tf.qint32`: Quantized 32-bit signed integer.
+
+ In addition, variants of these types with the `_ref` suffix are
+ defined for reference-typed tensors.
+
+ The `tf.as_dtype()` function converts numpy types and string type
+ names to a `DType` object.
+
+ @@is_compatible_with
+ @@name
+ @@base_dtype
+ @@is_ref_dtype
+ @@as_ref
+ @@is_integer
+ @@is_quantized
+
+ @@as_numpy_dtype
+ @@as_datatype_enum
+ """
+
+ def __init__(self, type_enum):
+ """Creates a new `DataType`.
+
+ NOTE(mrry): In normal circumstances, you should not need to
+ construct a DataType object directly. Instead, use the
+ types.as_dtype() function.
+
+ Args:
+ type_enum: A `types_pb2.DataType` enum value.
+
+ Raises:
+ TypeError: If `type_enum` is not a value `types_pb2.DataType`.
+
+ """
+ # TODO(mrry): Make the necessary changes (using __new__) to ensure
+ # that calling this returns one of the interned values.
+ type_enum = int(type_enum)
+ if (type_enum not in types_pb2.DataType.values()
+ or type_enum == types_pb2.DT_INVALID):
+ raise TypeError(
+ "type_enum is not a valid types_pb2.DataType: %s" % type_enum)
+ self._type_enum = type_enum
+
+ @property
+ def is_ref_dtype(self):
+ """Returns `True` if this `DType` represents a reference type."""
+ return self._type_enum > 100
+
+ @property
+ def as_ref(self):
+ """Returns a reference `DType` based on this `DType`."""
+ if self.is_ref_dtype:
+ return self
+ else:
+ return _INTERN_TABLE[self._type_enum + 100]
+
+ @property
+ def base_dtype(self):
+ """Returns a non-reference `DType` based on this `DType`."""
+ if self.is_ref_dtype:
+ return _INTERN_TABLE[self._type_enum - 100]
+ else:
+ return self
+
+ @property
+ def as_numpy_dtype(self):
+ """Returns a `numpy.dtype` based on this `DType`."""
+ return _TF_TO_NP[self._type_enum]
+
+ @property
+ def as_datatype_enum(self):
+ """Returns a `types_pb2.DataType` enum value based on this `DType`."""
+ return self._type_enum
+
+ @property
+ def is_integer(self):
+ """Returns whether this is a (non-quantized) integer type."""
+ return (not self.is_quantized and
+ issubclass(self.as_numpy_dtype, np.integer))
+
+ @property
+ def is_quantized(self):
+ """Returns whether this is a quantized data type."""
+ return self.base_dtype in [qint8, quint8, qint32, bfloat16]
+
+ @property
+ def min(self):
+ """Returns the minimum representable value in this data type.
+
+ Raises:
+ TypeError: if this is a non-numeric, unordered, or quantized type.
+
+ """
+ if (self.is_quantized or self.base_dtype == bool or
+ self.base_dtype == string or self.base_dtype == complex64):
+ raise TypeError("Cannot find minimum value of %s." % self)
+
+ # there is no simple way to get the min value of a dtype, we have to check
+ # float and int types separately
+ try:
+ return np.finfo(self.as_numpy_dtype()).min
+ except: # bare except as possible raises by finfo not documented
+ try:
+ return np.iinfo(self.as_numpy_dtype()).min
+ except:
+ raise TypeError("Cannot find minimum value of %s." % self)
+
+ @property
+ def max(self):
+ """Returns the maximum representable value in this data type.
+
+ Raises:
+ TypeError: if this is a non-numeric, unordered, or quantized type.
+
+ """
+ if (self.is_quantized or self.base_dtype == bool or
+ self.base_dtype == string or self.base_dtype == complex64):
+ raise TypeError("Cannot find maximum value of %s." % self)
+
+ # there is no simple way to get the min value of a dtype, we have to check
+ # float and int types separately
+ try:
+ return np.finfo(self.as_numpy_dtype()).max
+ except: # bare except as possible raises by finfo not documented
+ try:
+ return np.iinfo(self.as_numpy_dtype()).max
+ except:
+ raise TypeError("Cannot find maximum value of %s." % self)
+
+ def is_compatible_with(self, other):
+ """Returns True if the `other` DType will be converted to this DType.
+
+ The conversion rules are as follows:
+
+ ```
+ DType(T) .is_compatible_with(DType(T)) == True
+ DType(T) .is_compatible_with(DType(T).as_ref) == True
+ DType(T).as_ref.is_compatible_with(DType(T)) == False
+ DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
+ ```
+
+ Args:
+ other: A `DType` (or object that may be converted to a `DType`).
+
+ Returns:
+ True if a Tensor of the `other` `DType` will be implicitly converted to
+ this `DType`.
+ """
+ other = as_dtype(other)
+ return self._type_enum in (
+ other.as_datatype_enum, other.base_dtype.as_datatype_enum)
+
+ def __eq__(self, other):
+ """Returns True iff this DType refers to the same type as `other`."""
+ return (other is not None
+ and self._type_enum == as_dtype(other).as_datatype_enum)
+
+ def __ne__(self, other):
+ """Returns True iff self != other."""
+ return not self.__eq__(other)
+
+ @property
+ def name(self):
+ """Returns the string name for this `DType`."""
+ return _TYPE_TO_STRING[self._type_enum]
+
+ def __str__(self):
+ return "<dtype: %r>" % self.name
+
+ def __repr__(self):
+ return "tf." + self.name
+
+
+# Define standard wrappers for the types_pb2.DataType enum.
+# The `_ref` variants wrap the DT_*_REF enum values, each of which is its
+# base DT_* value plus 100 (the offset used by as_ref/base_dtype above).
+float32 = DType(types_pb2.DT_FLOAT)
+float64 = DType(types_pb2.DT_DOUBLE)
+double = float64
+int32 = DType(types_pb2.DT_INT32)
+uint8 = DType(types_pb2.DT_UINT8)
+int16 = DType(types_pb2.DT_INT16)
+int8 = DType(types_pb2.DT_INT8)
+string = DType(types_pb2.DT_STRING)
+complex64 = DType(types_pb2.DT_COMPLEX64)
+int64 = DType(types_pb2.DT_INT64)
+bool = DType(types_pb2.DT_BOOL)
+qint8 = DType(types_pb2.DT_QINT8)
+quint8 = DType(types_pb2.DT_QUINT8)
+qint32 = DType(types_pb2.DT_QINT32)
+bfloat16 = DType(types_pb2.DT_BFLOAT16)
+float32_ref = DType(types_pb2.DT_FLOAT_REF)
+float64_ref = DType(types_pb2.DT_DOUBLE_REF)
+double_ref = float64_ref
+int32_ref = DType(types_pb2.DT_INT32_REF)
+uint8_ref = DType(types_pb2.DT_UINT8_REF)
+int16_ref = DType(types_pb2.DT_INT16_REF)
+int8_ref = DType(types_pb2.DT_INT8_REF)
+string_ref = DType(types_pb2.DT_STRING_REF)
+complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
+int64_ref = DType(types_pb2.DT_INT64_REF)
+bool_ref = DType(types_pb2.DT_BOOL_REF)
+qint8_ref = DType(types_pb2.DT_QINT8_REF)
+quint8_ref = DType(types_pb2.DT_QUINT8_REF)
+qint32_ref = DType(types_pb2.DT_QINT32_REF)
+bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
+
+
+# Maintain an intern table so that we don't have to create a large
+# number of small objects.
+# Keyed by types_pb2.DataType enum value.  as_dtype(), as_ref and base_dtype
+# all resolve through this table, so equal dtypes are identical objects
+# (which is what lets callers compare them with `is`).
+_INTERN_TABLE = {
+    types_pb2.DT_FLOAT: float32,
+    types_pb2.DT_DOUBLE: float64,
+    types_pb2.DT_INT32: int32,
+    types_pb2.DT_UINT8: uint8,
+    types_pb2.DT_INT16: int16,
+    types_pb2.DT_INT8: int8,
+    types_pb2.DT_STRING: string,
+    types_pb2.DT_COMPLEX64: complex64,
+    types_pb2.DT_INT64: int64,
+    types_pb2.DT_BOOL: bool,
+    types_pb2.DT_QINT8: qint8,
+    types_pb2.DT_QUINT8: quint8,
+    types_pb2.DT_QINT32: qint32,
+    types_pb2.DT_BFLOAT16: bfloat16,
+    types_pb2.DT_FLOAT_REF: float32_ref,
+    types_pb2.DT_DOUBLE_REF: float64_ref,
+    types_pb2.DT_INT32_REF: int32_ref,
+    types_pb2.DT_UINT8_REF: uint8_ref,
+    types_pb2.DT_INT16_REF: int16_ref,
+    types_pb2.DT_INT8_REF: int8_ref,
+    types_pb2.DT_STRING_REF: string_ref,
+    types_pb2.DT_COMPLEX64_REF: complex64_ref,
+    types_pb2.DT_INT64_REF: int64_ref,
+    types_pb2.DT_BOOL_REF: bool_ref,
+    types_pb2.DT_QINT8_REF: qint8_ref,
+    types_pb2.DT_QUINT8_REF: quint8_ref,
+    types_pb2.DT_QINT32_REF: qint32_ref,
+    types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
+}
+
+
+# Standard mappings between types_pb2.DataType values and string names.
+# These names are what DType.name returns and what as_dtype() accepts;
+# the types_test checks they are unique per enum value.
+_TYPE_TO_STRING = {
+    types_pb2.DT_FLOAT: "float32",
+    types_pb2.DT_DOUBLE: "float64",
+    types_pb2.DT_INT32: "int32",
+    types_pb2.DT_UINT8: "uint8",
+    types_pb2.DT_INT16: "int16",
+    types_pb2.DT_INT8: "int8",
+    types_pb2.DT_STRING: "string",
+    types_pb2.DT_COMPLEX64: "complex64",
+    types_pb2.DT_INT64: "int64",
+    types_pb2.DT_BOOL: "bool",
+    types_pb2.DT_QINT8: "qint8",
+    types_pb2.DT_QUINT8: "quint8",
+    types_pb2.DT_QINT32: "qint32",
+    types_pb2.DT_BFLOAT16: "bfloat16",
+    types_pb2.DT_FLOAT_REF: "float32_ref",
+    types_pb2.DT_DOUBLE_REF: "float64_ref",
+    types_pb2.DT_INT32_REF: "int32_ref",
+    types_pb2.DT_UINT8_REF: "uint8_ref",
+    types_pb2.DT_INT16_REF: "int16_ref",
+    types_pb2.DT_INT8_REF: "int8_ref",
+    types_pb2.DT_STRING_REF: "string_ref",
+    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
+    types_pb2.DT_INT64_REF: "int64_ref",
+    types_pb2.DT_BOOL_REF: "bool_ref",
+    types_pb2.DT_QINT8_REF: "qint8_ref",
+    types_pb2.DT_QUINT8_REF: "quint8_ref",
+    types_pb2.DT_QINT32_REF: "qint32_ref",
+    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
+}
+# Inverse mapping: canonical string name -> interned DType.
+_STRING_TO_TF = {value: _INTERN_TABLE[key]
+                 for key, value in _TYPE_TO_STRING.iteritems()}
+# Add non-canonical aliases.
+_STRING_TO_TF["float"] = float32
+_STRING_TO_TF["float_ref"] = float32_ref
+_STRING_TO_TF["double"] = float64
+_STRING_TO_TF["double_ref"] = float64_ref
+
+
+# Numpy representation for quantized dtypes.
+#
+# These are magic strings that are used in the swig wrapper to identify
+# quantized types.
+# TODO(mrry,keveman): Investigate Numpy type registration to replace this
+# hard-coding of names.
+_np_qint8 = np.dtype([("qint8", np.int8, 1)])
+_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
+_np_qint32 = np.dtype([("qint32", np.int32, 1)])
+
+# Standard mappings between types_pb2.DataType values and numpy.dtypes.
+# A frozenset of pairs rather than a dict: as_dtype() scans it linearly,
+# comparing candidate numpy types with ==.
+_NP_TO_TF = frozenset([
+    (np.float32, float32),
+    (np.float64, float64),
+    (np.int32, int32),
+    (np.int64, int64),
+    (np.uint8, uint8),
+    (np.int16, int16),
+    (np.int8, int8),
+    (np.complex64, complex64),
+    (np.object, string),
+    (np.bool, bool),
+    (_np_qint8, qint8),
+    (_np_quint8, quint8),
+    (_np_qint32, qint32),
+    # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+])
+# Forward mapping used by DType.as_numpy_dtype; ref types map to the same
+# numpy dtype as their base type.
+_TF_TO_NP = {
+    types_pb2.DT_FLOAT: np.float32,
+    types_pb2.DT_DOUBLE: np.float64,
+    types_pb2.DT_INT32: np.int32,
+    types_pb2.DT_UINT8: np.uint8,
+    types_pb2.DT_INT16: np.int16,
+    types_pb2.DT_INT8: np.int8,
+    # NOTE(mdevin): For strings we use np.object as it supports variable length
+    # strings.
+    types_pb2.DT_STRING: np.object,
+    types_pb2.DT_COMPLEX64: np.complex64,
+    types_pb2.DT_INT64: np.int64,
+    types_pb2.DT_BOOL: np.bool,
+    types_pb2.DT_QINT8: _np_qint8,
+    types_pb2.DT_QUINT8: _np_quint8,
+    types_pb2.DT_QINT32: _np_qint32,
+    types_pb2.DT_BFLOAT16: np.uint16,
+
+    # Ref types
+    types_pb2.DT_FLOAT_REF: np.float32,
+    types_pb2.DT_DOUBLE_REF: np.float64,
+    types_pb2.DT_INT32_REF: np.int32,
+    types_pb2.DT_UINT8_REF: np.uint8,
+    types_pb2.DT_INT16_REF: np.int16,
+    types_pb2.DT_INT8_REF: np.int8,
+    types_pb2.DT_STRING_REF: np.object,
+    types_pb2.DT_COMPLEX64_REF: np.complex64,
+    types_pb2.DT_INT64_REF: np.int64,
+    types_pb2.DT_BOOL_REF: np.bool,
+    types_pb2.DT_QINT8_REF: _np_qint8,
+    types_pb2.DT_QUINT8_REF: _np_quint8,
+    types_pb2.DT_QINT32_REF: _np_qint32,
+    types_pb2.DT_BFLOAT16_REF: np.uint16,
+}
+
+
+# The set of all quantized DTypes (base and ref variants).
+QUANTIZED_DTYPES = frozenset(
+    [qint8, quint8, qint32, qint8_ref, quint8_ref, qint32_ref])
+
+
+def as_dtype(type_value):
+ """Converts the given `type_value` to a `DType`.
+
+ Args:
+ type_value: A value that can be converted to a `tf.DType`
+ object. This may currently be a `tf.DType` object, a
+ [`DataType` enum](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/types.proto),
+ a string type name, or a `numpy.dtype`.
+
+ Returns:
+ A `DType` corresponding to `type_value`.
+
+ Raises:
+ TypeError: If `type_value` cannot be converted to a `DType`.
+ """
+ if isinstance(type_value, DType):
+ return type_value
+
+ try:
+ return _INTERN_TABLE[type_value]
+ except KeyError:
+ pass
+
+ try:
+ return _STRING_TO_TF[type_value]
+ except KeyError:
+ pass
+
+ if isinstance(type_value, np.dtype):
+ # The numpy dtype for strings is variable length. We can not compare
+ # dtype with a single constant (np.string does not exist) to decide
+ # dtype is a "string" type. We need to compare the dtype.type to be
+ # sure it's a string type.
+ if type_value.type == np.string_ or type_value.type == np.unicode_:
+ return string
+
+ for key, val in _NP_TO_TF:
+ if key == type_value:
+ return val
+
+ raise TypeError(
+ "Cannot convert value %r to a TensorFlow DType." % type_value)
diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py
new file mode 100644
index 0000000000..acd2994339
--- /dev/null
+++ b/tensorflow/python/framework/types_test.py
@@ -0,0 +1,174 @@
+"""Tests for tensorflow.python.framework.importer."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.platform import googletest
+
+
+class TypesTest(test_util.TensorFlowTestCase):
+
+ def testAllTypesConstructible(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ self.assertEqual(
+ datatype_enum, types.DType(datatype_enum).as_datatype_enum)
+
+ def testAllTypesConvertibleToDType(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ self.assertEqual(
+ datatype_enum, types.as_dtype(datatype_enum).as_datatype_enum)
+
+ def testAllTypesConvertibleToNumpyDtype(self):
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ numpy_dtype = dtype.as_numpy_dtype
+ _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
+ if dtype.base_dtype != types.bfloat16:
+ # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
+ self.assertEqual(
+ types.as_dtype(datatype_enum).base_dtype, types.as_dtype(numpy_dtype))
+
+ def testInvalid(self):
+ with self.assertRaises(TypeError):
+ types.DType(types_pb2.DT_INVALID)
+ with self.assertRaises(TypeError):
+ types.as_dtype(types_pb2.DT_INVALID)
+
+ def testNumpyConversion(self):
+ self.assertIs(types.float32, types.as_dtype(np.float32))
+ self.assertIs(types.float64, types.as_dtype(np.float64))
+ self.assertIs(types.int32, types.as_dtype(np.int32))
+ self.assertIs(types.int64, types.as_dtype(np.int64))
+ self.assertIs(types.uint8, types.as_dtype(np.uint8))
+ self.assertIs(types.int16, types.as_dtype(np.int16))
+ self.assertIs(types.int8, types.as_dtype(np.int8))
+ self.assertIs(types.complex64, types.as_dtype(np.complex64))
+ self.assertIs(types.string, types.as_dtype(np.object))
+ self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype))
+ self.assertIs(types.bool, types.as_dtype(np.bool))
+ with self.assertRaises(TypeError):
+ types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))
+
+ def testStringConversion(self):
+ self.assertIs(types.float32, types.as_dtype("float32"))
+ self.assertIs(types.float64, types.as_dtype("float64"))
+ self.assertIs(types.int32, types.as_dtype("int32"))
+ self.assertIs(types.uint8, types.as_dtype("uint8"))
+ self.assertIs(types.int16, types.as_dtype("int16"))
+ self.assertIs(types.int8, types.as_dtype("int8"))
+ self.assertIs(types.string, types.as_dtype("string"))
+ self.assertIs(types.complex64, types.as_dtype("complex64"))
+ self.assertIs(types.int64, types.as_dtype("int64"))
+ self.assertIs(types.bool, types.as_dtype("bool"))
+ self.assertIs(types.qint8, types.as_dtype("qint8"))
+ self.assertIs(types.quint8, types.as_dtype("quint8"))
+ self.assertIs(types.qint32, types.as_dtype("qint32"))
+ self.assertIs(types.bfloat16, types.as_dtype("bfloat16"))
+ self.assertIs(types.float32_ref, types.as_dtype("float32_ref"))
+ self.assertIs(types.float64_ref, types.as_dtype("float64_ref"))
+ self.assertIs(types.int32_ref, types.as_dtype("int32_ref"))
+ self.assertIs(types.uint8_ref, types.as_dtype("uint8_ref"))
+ self.assertIs(types.int16_ref, types.as_dtype("int16_ref"))
+ self.assertIs(types.int8_ref, types.as_dtype("int8_ref"))
+ self.assertIs(types.string_ref, types.as_dtype("string_ref"))
+ self.assertIs(types.complex64_ref, types.as_dtype("complex64_ref"))
+ self.assertIs(types.int64_ref, types.as_dtype("int64_ref"))
+ self.assertIs(types.bool_ref, types.as_dtype("bool_ref"))
+ self.assertIs(types.qint8_ref, types.as_dtype("qint8_ref"))
+ self.assertIs(types.quint8_ref, types.as_dtype("quint8_ref"))
+ self.assertIs(types.qint32_ref, types.as_dtype("qint32_ref"))
+ self.assertIs(types.bfloat16_ref, types.as_dtype("bfloat16_ref"))
+ with self.assertRaises(TypeError):
+ types.as_dtype("not_a_type")
+
+ def testDTypesHaveUniqueNames(self):
+ dtypes = []
+ names = set()
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ dtypes.append(dtype)
+ names.add(dtype.name)
+ self.assertEqual(len(dtypes), len(names))
+
+ def testIsInteger(self):
+ self.assertEqual(types.as_dtype("int8").is_integer, True)
+ self.assertEqual(types.as_dtype("int16").is_integer, True)
+ self.assertEqual(types.as_dtype("int32").is_integer, True)
+ self.assertEqual(types.as_dtype("int64").is_integer, True)
+ self.assertEqual(types.as_dtype("uint8").is_integer, True)
+ self.assertEqual(types.as_dtype("complex64").is_integer, False)
+ self.assertEqual(types.as_dtype("float").is_integer, False)
+ self.assertEqual(types.as_dtype("double").is_integer, False)
+ self.assertEqual(types.as_dtype("string").is_integer, False)
+ self.assertEqual(types.as_dtype("bool").is_integer, False)
+
+ def testMinMax(self):
+ # make sure min/max evaluates for all data types that have min/max
+ for datatype_enum in types_pb2.DataType.values():
+ if datatype_enum == types_pb2.DT_INVALID:
+ continue
+ dtype = types.as_dtype(datatype_enum)
+ numpy_dtype = dtype.as_numpy_dtype
+
+ # ignore types for which there are no minimum/maximum (or we cannot
+ # compute it, such as for the q* types)
+ if (dtype.is_quantized or
+ dtype.base_dtype == types.bool or
+ dtype.base_dtype == types.string or
+ dtype.base_dtype == types.complex64):
+ continue
+
+ print "%s: %s - %s" % (dtype, dtype.min, dtype.max)
+
+ # check some values that are known
+ if numpy_dtype == np.bool_:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 1)
+ if numpy_dtype == np.int8:
+ self.assertEquals(dtype.min, -128)
+ self.assertEquals(dtype.max, 127)
+ if numpy_dtype == np.int16:
+ self.assertEquals(dtype.min, -32768)
+ self.assertEquals(dtype.max, 32767)
+ if numpy_dtype == np.int32:
+ self.assertEquals(dtype.min, -2147483648)
+ self.assertEquals(dtype.max, 2147483647)
+ if numpy_dtype == np.int64:
+ self.assertEquals(dtype.min, -9223372036854775808)
+ self.assertEquals(dtype.max, 9223372036854775807)
+ if numpy_dtype == np.uint8:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 255)
+ if numpy_dtype == np.uint16:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 4294967295)
+ if numpy_dtype == np.uint32:
+ self.assertEquals(dtype.min, 0)
+ self.assertEquals(dtype.max, 18446744073709551615)
+ if numpy_dtype in (np.float16, np.float32, np.float64):
+ self.assertEquals(dtype.min, np.finfo(numpy_dtype).min)
+ self.assertEquals(dtype.max, np.finfo(numpy_dtype).max)
+
+ def testRepr(self):
+ for enum, name in types._TYPE_TO_STRING.iteritems():
+ dtype = types.DType(enum)
+ self.assertEquals(repr(dtype), 'tf.' + name)
+ dtype2 = eval(repr(dtype))
+ self.assertEquals(type(dtype2), types.DType)
+ self.assertEquals(dtype, dtype2)
+
+
+# Run all tests in this module when executed as a script.
+if __name__ == "__main__":
+  googletest.main()