diff options
Diffstat (limited to 'tensorflow/python/framework')
29 files changed, 10209 insertions, 0 deletions
diff --git a/tensorflow/python/framework/__init__.py b/tensorflow/python/framework/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/python/framework/__init__.py diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py new file mode 100644 index 0000000000..676e5f779a --- /dev/null +++ b/tensorflow/python/framework/device.py @@ -0,0 +1,220 @@ +"""Class to represent a device.""" +import copy + + +class Device(object): + """Represents a Device.""" + + def __init__(self, job=None, replica=None, task=None, device_type=None, + device_index=None): + """Create a new device object. + + Args: + job: string. Optional device job name. + replica: int. Optional replica index. + task: int. Optional task index. + device_type: Optional device type string (e.g. "CPU" or "GPU") + device_index: int. Optional device index. If left + unspecified, device represents 'any' device_index. + """ + self.job = job + self.replica = replica + self.task = task + if device_type == "cpu" or device_type == "gpu": + # For backwards compatibility only, we support lowercase variants of + # cpu and gpu but turn them into uppercase here. 
+ self.device_type = device_type.upper() + else: + self.device_type = device_type + self.device_index = device_index + + def _clear(self): + self._job = None + self._replica = None + self._task = None + self.device_type = None + self.device_index = None + + @property + def job(self): + return self._job + + @job.setter + def job(self, job): + if job is not None: + self._job = str(job) + else: + self._job = None + + @property + def replica(self): + return self._replica + + @replica.setter + def replica(self, replica): + if replica is not None: + self._replica = int(replica) + else: + self._replica = None + + @property + def task(self): + return self._task + + @task.setter + def task(self, task): + if task is not None: + self._task = int(task) + else: + self._task = None + + def parse_from_string(self, spec): + """Parse a Device name into its components. + + Args: + spec: a string of the form + /job:<name>/replica:<id>/task:<id>/device:CPU:<id> + or + /job:<name>/replica:<id>/task:<id>/device:GPU:<id> + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + The Device, for convenience. + + Raises: + ValueError: if the spec was not valid. + """ + self._clear() + splits = [x.split(":") for x in spec.split("/")] + for y in splits: + ly = len(y) + if y: + # NOTE(mdevin): we use the property getters here. 
+ if ly == 2 and y[0] == "job": + self.job = y[1] + elif ly == 2 and y[0] == "replica": + self.replica = y[1] + elif ly == 2 and y[0] == "task": + self.task = y[1] + elif ((ly == 1 or ly == 2) and + ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))): + if self.device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + self.device_type = y[0].upper() + if ly == 2 and y[1] != "*": + self.device_index = int(y[1]) + elif ly == 3 and y[0] == "device": + if self.device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + self.device_type = y[1] + if y[2] != "*": + self.device_index = int(y[2]) + elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison + raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec)) + + return self + + def merge_from(self, dev): + """Merge the properties of "dev" into this Device. + + Args: + dev: a Device. + """ + if dev.job is not None: + self.job = dev.job + if dev.replica is not None: + self.replica = dev.replica + if dev.task is not None: + self.task = dev.task + if dev.device_type is not None: + self.device_type = dev.device_type + if dev.device_index is not None: + self.device_index = dev.device_index + + def to_string(self): + """Return a Device specification string. + + Returns: + a string of the form /job:<name>/replica:<id>/task:<id>/device:cpu:<id> + or /job:<name>/replica:<id>/task:<id>/device:cpu:<id>. + """ + dev = "" + if self.job is not None: + dev += "/job:" + self.job + if self.replica is not None: + dev += "/replica:" + str(self.replica) + if self.task is not None: + dev += "/task:" + str(self.task) + if self.device_type is not None: + device_index_string = "*" + if self.device_index is not None: + device_index_string = str(self.device_index) + dev += "/device:%s:%s" % (self.device_type, device_index_string) + return dev + + +def from_string(spec): + """Construct a Device from a string. 
+ + Args: + spec: a string of the form + /job:<name>/replica:<id>/task:<id>/device:CPU:<id> + or + /job:<name>/replica:<id>/task:<id>/device:GPU:<id> + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + A Device. + """ + return Device().parse_from_string(spec) + + +def check_valid(spec): + """Check that a device spec is valid. + + Args: + spec: a string. + + Raises: + An exception if the spec is invalid. + """ + # Construct a device. It will assert a failure if spec is invalid. + from_string(spec) + + +def merge_device(spec): + """Returns a device function that merges devices specifications. + + This can be used to merge partial specifications of devices. The + innermost setting for a device field takes precedence. For example: + + with tf.Device(MergeDevice("/device:GPU:0")) + # Nodes created here have device "/device:GPU:0" + with tf.Device(MergeDevice("/job:worker")): + # Nodes created here have device "/job:worker/device:GPU:0" + with tf.Device(MergeDevice("/device:CPU:0")): + # Nodes created here have device "/job:worker/device:CPU:0" + with tf.Device(MergeDevice("/job:ps")): + # Nodes created here have device "/job:ps/device:CPU:0" + + Args: + spec: A device or a device spec string (partially) describing the + device that should be used for all nodes created in the scope of + the returned device function's with block. + + Returns: + A device function with the above-described behavior. + + Raises: + ValueError: if the spec was not valid. + """ + if not isinstance(spec, Device): + spec = from_string(spec or "") + def _device_function(node_def): + current_device = from_string(node_def.device or "") + copy_spec = copy.copy(spec) + copy_spec.merge_from(current_device) # current_device takes precedence. 
+ return copy_spec + return _device_function diff --git a/tensorflow/python/framework/device_test.py b/tensorflow/python/framework/device_test.py new file mode 100644 index 0000000000..0a244b0815 --- /dev/null +++ b/tensorflow/python/framework/device_test.py @@ -0,0 +1,122 @@ +"""Tests for tensorflow.python.framework.device.""" +import tensorflow.python.platform + +from tensorflow.python.framework import device +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class DeviceTest(test_util.TensorFlowTestCase): + + def testEmpty(self): + d = device.Device() + self.assertEquals("", d.ToString()) + d.parse_from_string("") + self.assertEquals("", d.ToString()) + + def testConstructor(self): + d = device.Device(job="j", replica=0, task=1, + device_type="CPU", device_index=2) + self.assertEquals("j", d.job) + self.assertEquals(0, d.replica) + self.assertEquals(1, d.task) + self.assertEquals("CPU", d.device_type) + self.assertEquals(2, d.device_index) + self.assertEquals("/job:j/replica:0/task:1/device:CPU:2", d.to_string()) + + d = device.Device(device_type="GPU", device_index=0) + self.assertEquals("/device:GPU:0", d.to_string()) + + def testto_string(self): + d = device.Device() + d.job = "foo" + self.assertEquals("/job:foo", d.to_string()) + d.task = 3 + self.assertEquals("/job:foo/task:3", d.to_string()) + d.device_type = "CPU" + d.device_index = 0 + self.assertEquals("/job:foo/task:3/device:CPU:0", d.to_string()) + d.task = None + d.replica = 12 + self.assertEquals("/job:foo/replica:12/device:CPU:0", d.to_string()) + d.device_type = "GPU" + d.device_index = 2 + self.assertEquals("/job:foo/replica:12/device:GPU:2", d.to_string()) + d.device_type = "CPU" + d.device_index = 1 + self.assertEquals("/job:foo/replica:12/device:CPU:1", d.to_string()) + d.device_type = None + d.device_index = None + d.cpu = None + self.assertEquals("/job:foo/replica:12", d.to_string()) + + # Test wildcard + d = device.Device(job="foo", 
replica=12, task=3, device_type="GPU") + self.assertEquals("/job:foo/replica:12/task:3/device:GPU:*", d.to_string()) + + def testParse(self): + d = device.Device() + d.parse_from_string("/job:foo/replica:0") + self.assertEquals("/job:foo/replica:0", d.to_string()) + d.parse_from_string("/replica:1/task:0/cpu:0") + self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string()) + d.parse_from_string("/replica:1/task:0/device:CPU:0") + self.assertEquals("/replica:1/task:0/device:CPU:0", d.to_string()) + d.parse_from_string("/job:muu/gpu:2") + self.assertEquals("/job:muu/device:GPU:2", d.to_string()) + with self.assertRaises(Exception) as e: + d.parse_from_string("/job:muu/gpu:2/cpu:0") + self.assertTrue("Cannot specify multiple device" in e.exception.message) + + def testFromString(self): + d = device.from_string("/job:foo/replica:0") + self.assertEquals("/job:foo/replica:0", d.to_string()) + with self.assertRaises(Exception) as e: + d = device.from_string("/job:muu/gpu:2/cpu:0") + self.assertTrue("Cannot specify multiple device" in e.exception.message) + + d = device.from_string("/job:foo/replica:0/task:3/cpu:*") + self.assertEquals(None, d.device_index) + d = device.from_string("/job:foo/replica:0/task:3/gpu:7") + self.assertEquals(7, d.device_index) + d = device.from_string("/job:foo/replica:0/task:3/device:GPU:7") + self.assertEquals(7, d.device_index) + + def testMerge(self): + d = device.from_string("/job:foo/replica:0") + self.assertEquals("/job:foo/replica:0", d.to_string()) + d.merge_from(device.from_string("/task:1/gpu:2")) + self.assertEquals("/job:foo/replica:0/task:1/device:GPU:2", d.to_string()) + + d = device.Device() + d.merge_from(device.from_string("/task:1/cpu:0")) + self.assertEquals("/task:1/device:CPU:0", d.to_string()) + d.merge_from(device.from_string("/job:boo/gpu:0")) + self.assertEquals("/job:boo/task:1/device:GPU:0", d.to_string()) + d.merge_from(device.from_string("/job:muu/cpu:2")) + self.assertEquals("/job:muu/task:1/device:CPU:2", 
d.to_string()) + d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2")) + self.assertEquals("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string()) + + def testCheckValid(self): + device.CheckValid("/job:foo/replica:0") + + with self.assertRaises(Exception) as e: + device.CheckValid("/job:j/replica:foo") + self.assertTrue("invalid literal for int" in e.exception.message) + + with self.assertRaises(Exception) as e: + device.CheckValid("/job:j/task:bar") + self.assertTrue("invalid literal for int" in e.exception.message) + + with self.assertRaises(Exception) as e: + device.CheckValid("/bar:muu/baz:2") + self.assertTrue("Unknown attribute: 'bar'" in e.exception.message) + + with self.assertRaises(Exception) as e: + device.CheckValid("/cpu:0/gpu:2") + self.assertTrue("Cannot specify multiple device" in e.exception.message) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py new file mode 100644 index 0000000000..68dbb3df72 --- /dev/null +++ b/tensorflow/python/framework/docs.py @@ -0,0 +1,492 @@ +"""Updates generated docs from Python doc comments. + +Both updates the files in the file-system and executes g4 commands to +make sure any changes are ready to be submitted. +""" + +import inspect +import os +import re +import sys + + +_arg_re = re.compile(" *([*]{0,2}[a-zA-Z][a-zA-Z0-9_]*):") +_section_re = re.compile("([A-Z][a-zA-Z ]*):$") +_always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]") +_anchor_re = re.compile(r"^[\w.]+$") +_member_mark = "@@" + + +class Document(object): + """Base class for an automatically generated document.""" + + def write_markdown_to_file(self, f): + """Writes a Markdown-formatted version of this document to file `f`. + + Args: + f: The output file. 
+ """ + raise NotImplementedError("Document.WriteToFile") + + +class Index(Document): + """An automatically generated index for a collection of documents.""" + + def __init__(self, module_to_name, members, filename_to_library_map): + """Creates a new Index. + + Args: + module_to_name: Dictionary mapping modules to short names. + members: Dictionary mapping member name to (fullname, member). + filename_to_library_map: A list of (filename, Library) pairs. The order + corresponds to the order in which the libraries appear in the index. + """ + self._module_to_name = module_to_name + self._members = members + self._filename_to_library_map = filename_to_library_map + + def write_markdown_to_file(self, f): + """Writes this index to file `f`. + + The output is formatted as an unordered list. Each list element + contains the title of the library, followed by a list of symbols + in that library hyperlinked to the corresponding anchor in that + library. + + Args: + f: The output file. + """ + print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->" + print >>f, "" + print >>f, "# TensorFlow Python reference documentation" + print >>f, "" + for filename, library in self._filename_to_library_map: + per_symbol_links = [] + for name in sorted(library.mentioned): + if name in self._members: + fullname, member = self._members[name] + anchor = _get_anchor(self._module_to_name, fullname) + prefix = "class " * inspect.isclass(member) + per_symbol_links.append("[%s%s](%s#%s)" % + (prefix, name, filename, anchor)) + if per_symbol_links: + print >>f, "* <b>[%s](%s)</b>: %s" % (library.title, filename, + ",\n ".join(per_symbol_links)) + print >>f, "" + + # actually include the files right here + print >>f, '<div class="sections-order" style="display: none;">\n<!--' + for filename, _ in self._filename_to_library_map: + print >>f, "<!-- %s -->" % filename + print >>f, "-->\n</div>" + +def collect_members(module_to_name): + """Collect all symbols from a list of modules. 
+ + Args: + module_to_name: Dictionary mapping modules to short names. + + Returns: + Dictionary mapping name to (fullname, member) pairs. + """ + members = {} + for module, module_name in module_to_name.iteritems(): + for name, member in inspect.getmembers(module): + if ((inspect.isfunction(member) or inspect.isclass(member)) and + not _always_drop_symbol_re.match(name)): + fullname = '%s.%s' % (module_name, name) + if name in members: + other_fullname, other_member = members[name] + if member is not other_member: + raise RuntimeError("Short name collision between %s and %s" % + (fullname, other_fullname)) + if len(fullname) == len(other_fullname): + raise RuntimeError("Can't decide whether to use %s or %s for %s: " + "both full names have length %d" % + (fullname, other_fullname, len(fullname))) + if len(fullname) > len(other_fullname): + continue # Use the shorter full name + members[name] = fullname, member + return members + + +def _get_anchor(module_to_name, fullname): + """Turn a full member name into an anchor. + + Args: + module_to_name: Dictionary mapping modules to short names. + fullname: Fully qualified name of symbol. + + Returns: + HTML anchor string. The longest module name prefix of fullname is + removed to make the anchor. + + Raises: + ValueError: If fullname uses characters invalid in an anchor. + """ + if not _anchor_re.match(fullname): + raise ValueError("'%s' is not a valid anchor" % fullname) + anchor = fullname + for module_name in module_to_name.itervalues(): + if fullname.startswith(module_name + "."): + rest = fullname[len(module_name)+1:] + # Use this prefix iff it is longer than any found before + if len(anchor) > len(rest): + anchor = rest + return anchor + + +class Library(Document): + """An automatically generated document for a set of functions and classes.""" + + def __init__(self, + title, + module, + module_to_name, + members, + documented, + exclude_symbols=(), + catch_all=False): + """Creates a new Library. 
+ + Args: + title: A human-readable title for the library. + module: Module to pull high level docstring from (for table of contents, + list of Ops to document, etc.). + module_to_name: Dictionary mapping modules to short names. + members: Dictionary mapping member name to (fullname, member). + documented: Set of documented names to update. + exclude_symbols: A list of specific symbols to exclude. + """ + self._title = title + self._module = module + self._module_to_name = module_to_name + self._members = dict(members) # Copy since we mutate it below + self._exclude_symbols = frozenset(exclude_symbols) + documented.update(exclude_symbols) + self._documented = documented + self._mentioned = set() + + @property + def title(self): + """The human-readable title for this library.""" + return self._title + + @property + def mentioned(self): + """Set of names mentioned in this library.""" + return self._mentioned + + @property + def exclude_symbols(self): + """Set of excluded symbols.""" + return self._exclude_symbols + + def _should_include_member(self, name, member): + """Returns True if this member should be included in the document.""" + # Always exclude symbols matching _always_drop_symbol_re. + if _always_drop_symbol_re.match(name): + return False + # Finally, exclude any specifically-excluded symbols. + if name in self._exclude_symbols: + return False + return True + + def get_imported_modules(self, module): + """Returns the list of modules imported from `module`.""" + for name, member in inspect.getmembers(module): + if inspect.ismodule(member): + yield name, member + + def get_class_members(self, cls_name, cls): + """Returns the list of class members to document in `cls`. + + This function filters the class member to ONLY return those + defined by the class. It drops the inherited ones. + + Args: + cls_name: Qualified name of `cls`. + cls: An inspect object of type 'class'. + + Yields: + name, member tuples. 
+ """ + for name, member in inspect.getmembers(cls): + # Only show methods and properties presently. + if not (inspect.ismethod(member) or isinstance(member, property)): + continue + if ((inspect.ismethod(member) and member.__name__ == "__init__") + or self._should_include_member(name, member)): + yield name, ("%s.%s" % (cls_name, name), member) + + def _generate_signature_for_function(self, func): + """Given a function, returns a string representing its args.""" + args_list = [] + argspec = inspect.getargspec(func) + first_arg_with_default = ( + len(argspec.args or []) - len(argspec.defaults or [])) + for arg in argspec.args[:first_arg_with_default]: + if arg == "self": + # Python documentation typically skips `self` when printing method + # signatures. + continue + args_list.append(arg) + if argspec.defaults: + for arg, default in zip( + argspec.args[first_arg_with_default:], argspec.defaults): + args_list.append("%s=%r" % (arg, default)) + if argspec.varargs: + args_list.append("*" + argspec.varargs) + if argspec.keywords: + args_list.append("**" + argspec.keywords) + return "(" + ", ".join(args_list) + ")" + + def _remove_docstring_indent(self, docstring): + """Remove indenting. + + We follow Python's convention and remove the minimum indent of the lines + after the first, see: + https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation + preserving relative indentation. + + Args: + docstring: A docstring. + + Returns: + A list of strings, one per line, with the minimum indent stripped. 
+ """ + docstring = docstring or "" + lines = docstring.strip().split("\n") + + min_indent = len(docstring) + for l in lines[1:]: + l = l.rstrip() + if l: + i = 0 + while i < len(l) and l[i] == " ": + i += 1 + if i < min_indent: min_indent = i + for i in range(1, len(lines)): + l = lines[i].rstrip() + if len(l) >= min_indent: + l = l[min_indent:] + lines[i] = l + return lines + + def _print_formatted_docstring(self, docstring, f): + """Formats the given `docstring` as Markdown and prints it to `f`.""" + lines = self._remove_docstring_indent(docstring) + + # Output the lines, identifying "Args" and other section blocks. + i = 0 + + def _at_start_of_section(): + """Returns the header if lines[i] is at start of a docstring section.""" + l = lines[i] + match = _section_re.match(l) + if match and i + 1 < len( + lines) and lines[i + 1].startswith(" "): + return match.group(1) + else: + return None + + while i < len(lines): + l = lines[i] + + section_header = _at_start_of_section() + if section_header: + if i == 0 or lines[i-1]: + print >>f, "" + # Use at least H4 to keep these out of the TOC. + print >>f, "##### " + section_header + ":" + print >>f, "" + i += 1 + outputting_list = False + while i < len(lines): + l = lines[i] + # A new section header terminates the section. + if _at_start_of_section(): + break + match = _arg_re.match(l) + if match: + if not outputting_list: + # We need to start a list. In Markdown, a blank line needs to + # precede a list. + print >>f, "" + outputting_list = True + suffix = l[len(match.group()):].lstrip() + print >>f, "* <b>" + match.group(1) + "</b>: " + suffix + else: + # For lines that don't start with _arg_re, continue the list if it + # has enough indentation. 
+ outputting_list &= l.startswith(" ") + print >>f, l + i += 1 + else: + print >>f, l + i += 1 + + def _print_function(self, f, prefix, fullname, func): + """Prints the given function to `f`.""" + heading = prefix + " " + fullname + if not isinstance(func, property): + heading += self._generate_signature_for_function(func) + heading += " {#%s}" % _get_anchor(self._module_to_name, fullname) + print >>f, heading + print >>f, "" + self._print_formatted_docstring(inspect.getdoc(func), f) + print >>f, "" + + def _write_member_markdown_to_file(self, f, name, member): + """Print `member` to `f`.""" + if inspect.isfunction(member): + print >>f, "- - -" + print >>f, "" + self._print_function(f, "###", name, member) + print >>f, "" + elif inspect.ismethod(member): + print >>f, "- - -" + print >>f, "" + self._print_function(f, "####", name, member) + print >>f, "" + elif isinstance(member, property): + print >>f, "- - -" + print >>f, "" + self._print_function(f, "####", name, member) + elif inspect.isclass(member): + print >>f, "- - -" + print >>f, "" + print >>f, "### class %s {#%s}" % ( + name, _get_anchor(self._module_to_name, name)) + print >>f, "" + self._write_class_markdown_to_file(f, name, member) + print >>f, "" + else: + raise RuntimeError("Member %s has unknown type %s" % (name, type(member))) + + def _write_docstring_markdown_to_file(self, f, docstring, members, imports): + for l in self._remove_docstring_indent(docstring): + if l.startswith(_member_mark): + name = l[len(_member_mark):].strip(" \t") + if name in members: + self._documented.add(name) + self._mentioned.add(name) + self._write_member_markdown_to_file(f, *members[name]) + del members[name] + elif name in imports: + self._write_module_markdown_to_file(f, imports[name]) + else: + raise ValueError("%s: unknown member `%s`" % (self._title, name)) + else: + print >>f, l + + def _write_class_markdown_to_file(self, f, name, cls): + """Write the class doc to 'f'. + + Args: + f: File to write to. 
+ prefix: Prefix for names. + cls: class object. + name: name to use. + """ + # Build the list of class methods to document. + methods = dict(self.get_class_members(name, cls)) + # Used later to check if any methods were called out in the class + # docstring. + num_methods = len(methods) + self._write_docstring_markdown_to_file(f, inspect.getdoc(cls), methods, {}) + + # If some methods were not described, describe them now if they are + # defined by the class itself (not inherited). If NO methods were + # described, describe all methods. + # + # TODO(mdevin): when all methods have been categorized make it an error + # if some methods are not categorized. + any_method_called_out = (len(methods) != num_methods) + if any_method_called_out: + other_methods = {n: m for n, m in methods.iteritems() + if n in cls.__dict__} + if other_methods: + print >>f, "\n#### Other Methods" + else: + other_methods = methods + for name in sorted(other_methods): + self._write_member_markdown_to_file(f, *other_methods[name]) + + def _write_module_markdown_to_file(self, f, module): + imports = dict(self.get_imported_modules(module)) + self._write_docstring_markdown_to_file(f, inspect.getdoc(module), + self._members, imports) + + def write_markdown_to_file(self, f): + """Prints this library to file `f`. + + Args: + f: File to write to. + + Returns: + Dictionary of documented members. + """ + print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->" + print >>f, "" + # TODO(mdevin): Do not insert these. Let the doc writer put them in + # the module docstring explicitly. + print >>f, "#", self._title + print >>f, "[TOC]" + print >>f, "" + if self._module is not None: + self._write_module_markdown_to_file(f, self._module) + + def write_other_members(self, f, catch_all=False): + """Writes the leftover members to `f`. + + Args: + f: File to write to. + catch_all: If true, document all missing symbols from any module. + Otherwise, document missing symbols from just this module. 
+ """ + if catch_all: + names = self._members.iteritems() + else: + names = inspect.getmembers(self._module) + leftovers = [] + for name, _ in names: + if name in self._members and name not in self._documented: + leftovers.append(name) + if leftovers: + print "%s: undocumented members: %d" % (self._title, len(leftovers)) + print >>f, "\n## Other Functions and Classes" + for name in sorted(leftovers): + print " %s" % name + self._documented.add(name) + self._mentioned.add(name) + self._write_member_markdown_to_file(f, *self._members[name]) + + def assert_no_leftovers(self): + """Generate an error if there are leftover members.""" + leftovers = [] + for name in self._members.iterkeys(): + if name in self._members and name not in self._documented: + leftovers.append(name) + if leftovers: + raise RuntimeError("%s: undocumented members: %s" % + (self._title, ", ".join(leftovers))) + + +def write_libraries(dir, libraries): + """Write a list of libraries to disk. + + Args: + dir: Output directory. + libraries: List of (filename, library) pairs. + """ + files = [open(os.path.join(dir, k), "w") for k, _ in libraries] + # Document mentioned symbols for all libraries + for f, (_, v) in zip(files, libraries): + v.write_markdown_to_file(f) + # Document symbols that no library mentioned. We do this after writing + # out all libraries so that earlier libraries know what later libraries + # documented. + for f, (_, v) in zip(files, libraries): + v.write_other_members(f) + f.close() diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py new file mode 100644 index 0000000000..fe8f107cec --- /dev/null +++ b/tensorflow/python/framework/errors.py @@ -0,0 +1,410 @@ +"""Exception types for TensorFlow errors.""" +import traceback +import warnings + +from tensorflow.core.lib.core import error_codes_pb2 + + +class OpError(Exception): + """A generic error that is raised when TensorFlow execution fails. 
+ + Whenever possible, the session will raise a more specific subclass + of `OpError` from the `tf.errors` module. + + @@op + @@node_def + """ + + def __init__(self, node_def, op, message, error_code): + """Creates a new OpError indicating that a particular op failed. + + Args: + node_def: The graph_pb2.NodeDef proto representing the op that failed. + op: The ops.Operation that failed, if known; otherwise None. + message: The message string describing the failure. + error_code: The error_codes_pb2.Code describing the error. + """ + super(OpError, self).__init__() + self._message = message + self._node_def = node_def + self._op = op + self._error_code = error_code + + @property + def message(self): + """The error message that describes the error.""" + return self._message + + @property + def op(self): + """The operation that failed, if known. + + *N.B.* If the failed op was synthesized at runtime, e.g. a `Send` + or `Recv` op, there will be no corresponding + [`Operation`](framework.md#Operation) object. In that case, this + will return `None`, and you should instead use the + [`node_def`](OpError.node_def) to discover information about the op. + + Returns: + The `Operation` that failed, or None. 
+ """ + return self._op + + @property + def error_code(self): + """The integer error code that describes the error.""" + return self._error_code + + @property + def node_def(self): + """The `NodeDef` proto representing the op that failed.""" + return self._node_def + + def __str__(self): + if self._op is not None: + output = ["%s\nCaused by op %r, defined at:\n" + % (self.message, self._op.name,)] + curr_traceback_list = traceback.format_list(self._op.traceback) + output.extend(curr_traceback_list) + original_op = self._op._original_op + while original_op is not None: + output.append( + "\n...which was originally created as op %r, defined at:\n" + % (original_op.name,)) + prev_traceback_list = curr_traceback_list + curr_traceback_list = traceback.format_list(original_op.traceback) + + # Attempt to elide large common subsequences of the subsequent + # stack traces. + # + # TODO(mrry): Consider computing the actual longest common subsequence. + is_eliding = False + elide_count = 0 + last_elided_line = None + for line, line_in_prev in zip(curr_traceback_list, prev_traceback_list): + if line == line_in_prev: + if is_eliding: + elide_count += 1 + last_elided_line = line + else: + output.append(line) + is_eliding = True + elide_count = 0 + else: + if is_eliding: + if elide_count > 0: + output.extend( + ["[elided %d identical lines from previous traceback]\n" + % (elide_count - 1,), last_elided_line]) + is_eliding = False + output.extend(line) + + original_op = original_op._original_op + return ''.join(output) + else: + return self.message + + +OK = error_codes_pb2.OK +CANCELLED = error_codes_pb2.CANCELLED +UNKNOWN = error_codes_pb2.UNKNOWN +INVALID_ARGUMENT = error_codes_pb2.INVALID_ARGUMENT +DEADLINE_EXCEEDED = error_codes_pb2.DEADLINE_EXCEEDED +NOT_FOUND = error_codes_pb2.NOT_FOUND +ALREADY_EXISTS = error_codes_pb2.ALREADY_EXISTS +PERMISSION_DENIED = error_codes_pb2.PERMISSION_DENIED +UNAUTHENTICATED = error_codes_pb2.UNAUTHENTICATED +RESOURCE_EXHAUSTED = 
error_codes_pb2.RESOURCE_EXHAUSTED +FAILED_PRECONDITION = error_codes_pb2.FAILED_PRECONDITION +ABORTED = error_codes_pb2.ABORTED +OUT_OF_RANGE = error_codes_pb2.OUT_OF_RANGE +UNIMPLEMENTED = error_codes_pb2.UNIMPLEMENTED +INTERNAL = error_codes_pb2.INTERNAL +UNAVAILABLE = error_codes_pb2.UNAVAILABLE +DATA_LOSS = error_codes_pb2.DATA_LOSS + + +class CancelledError(OpError): + """Raised when an operation or step is cancelled. + + For example, a long-running operation (e.g. + [`queue.enqueue()`](io_ops.md#QueueBase.enqueue) may be cancelled by + running another operation (e.g. + [`queue.close(cancel_pending_enqueues=True)`](io_ops.md#QueueBase.close), + or by [closing the session](client.md#Session.close). A step that is + running such a long-running operation will fail by raising `CancelledError`. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `CancelledError`.""" + super(CancelledError, self).__init__(node_def, op, message, CANCELLED) + + +class UnknownError(OpError): + """Unknown error. + + An example of where this error may be returned is if a Status value + received from another address space belongs to an error-space that + is not known to this address space. Also errors raised by APIs that + do not return enough error information may be converted to this + error. + + @@__init__ + """ + + def __init__(self, node_def, op, message, error_code=UNKNOWN): + """Creates an `UnknownError`.""" + super(UnknownError, self).__init__(node_def, op, message, error_code) + + +class InvalidArgumentError(OpError): + """Raised when an operation receives an invalid argument. + + This may occur, for example, if an operation is receives an input + tensor that has an invalid value or shape. 
For example, the + [`tf.matmul()`](math_ops.md#matmul) op will raise this error if it + receives an input that is not a matrix, and the + [`tf.reshape()`](array_ops.md#reshape) op will raise this error if + the new shape does not match the number of elements in the input + tensor. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `InvalidArgumentError`.""" + super(InvalidArgumentError, self).__init__(node_def, op, message, + INVALID_ARGUMENT) + + +class DeadlineExceededError(OpError): + """Raised when a deadline expires before an operation could complete. + + This exception is not currently used. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `DeadlineExceededError`.""" + super(DeadlineExceededError, self).__init__(node_def, op, message, + DEADLINE_EXCEEDED) + + +class NotFoundError(OpError): + """Raised when a requested entity (e.g., a file or directory) was not found. + + For example, running the + [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation + could raise `NotFoundError` if it receives the name of a file that + does not exist. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `NotFoundError`.""" + super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND) + + +class AlreadyExistsError(OpError): + """Raised when an entity that we attempted to create already exists. + + For example, running an operation that saves a file + (e.g. [`tf.train.Saver.save()`](train.md#Saver.save)) could + potentially raise this exception if an explicit filename for an + existing file was passed. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `AlreadyExistsError`.""" + super(AlreadyExistsError, self).__init__(node_def, op, message, + ALREADY_EXISTS) + + +class PermissionDeniedError(OpError): + """Raised when the caller does not have permission to run an operation. 
+ + For example, running the + [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation + could raise `PermissionDeniedError` if it receives the name of a + file for which the user does not have the read file permission. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `PermissionDeniedError`.""" + super(PermissionDeniedError, self).__init__(node_def, op, message, + PERMISSION_DENIED) + + +class UnauthenticatedError(OpError): + """The request does not have valid authentication credentials. + + This exception is not currently used. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `UnauthenticatedError`.""" + super(UnauthenticatedError, self).__init__(node_def, op, message, + UNAUTHENTICATED) + + +class ResourceExhaustedError(OpError): + """Some resource has been exhausted. + + For example, this error might be raised if a per-user quota is + exhausted, or perhaps the entire file system is out of space. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `ResourceExhaustedError`.""" + super(ResourceExhaustedError, self).__init__(node_def, op, message, + RESOURCE_EXHAUSTED) + + +class FailedPreconditionError(OpError): + """Operation was rejected because the system is not in a state to execute it. + + This exception is most commonly raised when running an operation + that reads a [`tf.Variable`](state_ops.md#Variable) before it has + been initialized. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `FailedPreconditionError`.""" + super(FailedPreconditionError, self).__init__(node_def, op, message, + FAILED_PRECONDITION) + + +class AbortedError(OpError): + """The operation was aborted, typically due to a concurrent action. + + For example, running a [`queue.enqueue()`](io_ops.md#QueueBase.enqueue) + operation may raise `AbortedError` if a + [`queue.close()`](io_ops.md@QueueBase.close) operation previously ran. 
+ + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `AbortedError`.""" + super(AbortedError, self).__init__(node_def, op, message, ABORTED) + + +class OutOfRangeError(OpError): + """Raised when an operation executed past the valid range. + + This exception is raised in "end-of-file" conditions, such as when a + [`queue.dequeue()`](io_ops.md#QueueBase.dequeue) operation is + blocked on an empty queue, and a + [`queue.close()`](io_ops.md#QueueBase.close) operation executes. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `OutOfRangeError`.""" + super(OutOfRangeError, self).__init__(node_def, op, message, + OUT_OF_RANGE) + + +class UnimplementedError(OpError): + """Raised when an operation has not been implemented. + + Some operations may raise this error when passed otherwise-valid + arguments that it does not currently support. For example, running + the [`tf.nn.max_pool()`](nn.md#max_pool) operation would raise this + error if pooling was requested on the batch dimension, because this + is not yet supported. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `UnimplementedError`.""" + super(UnimplementedError, self).__init__(node_def, op, message, + UNIMPLEMENTED) + + +class InternalError(OpError): + """Raised when the system experiences an internal error. + + This exception is raised when some invariant expected by the runtime + has been broken. Catching this exception is not recommended. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `InternalError`.""" + super(InternalError, self).__init__(node_def, op, message, INTERNAL) + + +class UnavailableError(OpError): + """Raised when the runtime is currently unavailable. + + This exception is not currently used. 
+ + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates an `UnavailableError`.""" + super(UnavailableError, self).__init__(node_def, op, message, + UNAVAILABLE) + + +class DataLossError(OpError): + """Raised when unrecoverable data loss or corruption is encountered. + + For example, this may be raised by running a + [`tf.WholeFileReader.read()`](io_ops.md#WholeFileReader) operation, + if the file is truncated while it is being read. + + @@__init__ + """ + + def __init__(self, node_def, op, message): + """Creates a `DataLossError`.""" + super(DataLossError, self).__init__(node_def, op, message, DATA_LOSS) + + +_CODE_TO_EXCEPTION_CLASS = { + CANCELLED: CancelledError, + UNKNOWN: UnknownError, + INVALID_ARGUMENT: InvalidArgumentError, + DEADLINE_EXCEEDED: DeadlineExceededError, + NOT_FOUND: NotFoundError, + ALREADY_EXISTS: AlreadyExistsError, + PERMISSION_DENIED: PermissionDeniedError, + UNAUTHENTICATED: UnauthenticatedError, + RESOURCE_EXHAUSTED: ResourceExhaustedError, + FAILED_PRECONDITION: FailedPreconditionError, + ABORTED: AbortedError, + OUT_OF_RANGE: OutOfRangeError, + UNIMPLEMENTED: UnimplementedError, + INTERNAL: InternalError, + UNAVAILABLE: UnavailableError, + DATA_LOSS: DataLossError, +} + + +def _make_specific_exception(node_def, op, message, error_code): + try: + exc_type = _CODE_TO_EXCEPTION_CLASS[error_code] + return exc_type(node_def, op, message) + except KeyError: + warnings.warn("Unknown error code: %d" % error_code) + return UnknownError(node_def, op, message, error_code) diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py new file mode 100644 index 0000000000..ab59a729f6 --- /dev/null +++ b/tensorflow/python/framework/errors_test.py @@ -0,0 +1,63 @@ +"""Tests for tensorflow.python.framework.errors.""" +import tensorflow.python.platform + +import warnings + +import tensorflow as tf + +from tensorflow.core.lib.core import error_codes_pb2 + +class 
ErrorsTest(tf.test.TestCase): + + def testUniqueClassForEachErrorCode(self): + for error_code, exc_type in [ + (tf.errors.CANCELLED, tf.errors.CancelledError), + (tf.errors.UNKNOWN, tf.errors.UnknownError), + (tf.errors.INVALID_ARGUMENT, tf.errors.InvalidArgumentError), + (tf.errors.DEADLINE_EXCEEDED, tf.errors.DeadlineExceededError), + (tf.errors.NOT_FOUND, tf.errors.NotFoundError), + (tf.errors.ALREADY_EXISTS, tf.errors.AlreadyExistsError), + (tf.errors.PERMISSION_DENIED, tf.errors.PermissionDeniedError), + (tf.errors.UNAUTHENTICATED, tf.errors.UnauthenticatedError), + (tf.errors.RESOURCE_EXHAUSTED, tf.errors.ResourceExhaustedError), + (tf.errors.FAILED_PRECONDITION, tf.errors.FailedPreconditionError), + (tf.errors.ABORTED, tf.errors.AbortedError), + (tf.errors.OUT_OF_RANGE, tf.errors.OutOfRangeError), + (tf.errors.UNIMPLEMENTED, tf.errors.UnimplementedError), + (tf.errors.INTERNAL, tf.errors.InternalError), + (tf.errors.UNAVAILABLE, tf.errors.UnavailableError), + (tf.errors.DATA_LOSS, tf.errors.DataLossError), + ]: + # pylint: disable=protected-access + self.assertTrue(isinstance( + tf.errors._make_specific_exception(None, None, None, error_code), + exc_type)) + # pylint: enable=protected-access + + def testKnownErrorClassForEachErrorCodeInProto(self): + for error_code in error_codes_pb2.Code.values(): + # pylint: disable=line-too-long + if error_code in (error_codes_pb2.OK, + error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_): + continue + # pylint: enable=line-too-long + with warnings.catch_warnings(record=True) as w: + # pylint: disable=protected-access + exc = tf.errors._make_specific_exception(None, None, None, error_code) + # pylint: enable=protected-access + self.assertEqual(0, len(w)) # No warning is raised. 
+ self.assertTrue(isinstance(exc, tf.errors.OpError)) + self.assertTrue(tf.errors.OpError in exc.__class__.__bases__) + + def testUnknownErrorCodeCausesWarning(self): + with warnings.catch_warnings(record=True) as w: + # pylint: disable=protected-access + exc = tf.errors._make_specific_exception(None, None, None, 37) + # pylint: enable=protected-access + self.assertEqual(1, len(w)) + self.assertTrue("Unknown error code: 37" in str(w[0].message)) + self.assertTrue(isinstance(exc, tf.errors.OpError)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py new file mode 100644 index 0000000000..e317cfda8d --- /dev/null +++ b/tensorflow/python/framework/framework_lib.py @@ -0,0 +1,70 @@ +# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long +"""Import names from the framework library. + +## Core graph data structures + +@@Graph +@@Operation +@@Tensor + +## Tensor types + +@@DType +@@as_dtype + +## Utility functions + +@@device +@@name_scope +@@control_dependencies +@@convert_to_tensor +@@get_default_graph +@@import_graph_def + +## Graph collections + +@@add_to_collection +@@get_collection +@@GraphKeys + +## Defining new operations + +@@RegisterGradient +@@NoGradient +@@RegisterShape +@@TensorShape +@@Dimension +@@op_scope +@@get_seed +""" + +# Classes used when building a Graph. +from tensorflow.python.framework.ops import Graph +from tensorflow.python.framework.ops import Operation +from tensorflow.python.framework.ops import Tensor +from tensorflow.python.framework.ops import SparseTensor +from tensorflow.python.framework.ops import SparseTensorValue +from tensorflow.python.framework.ops import IndexedSlices + +# Utilities used when building a Graph. 
+from tensorflow.python.framework.ops import device +from tensorflow.python.framework.ops import name_scope +from tensorflow.python.framework.ops import op_scope +from tensorflow.python.framework.ops import control_dependencies +from tensorflow.python.framework.ops import get_default_graph +from tensorflow.python.framework.ops import GraphKeys +from tensorflow.python.framework.ops import add_to_collection +from tensorflow.python.framework.ops import get_collection +from tensorflow.python.framework.ops import convert_to_tensor +from tensorflow.python.framework.random_seed import get_seed +from tensorflow.python.framework.random_seed import set_random_seed +from tensorflow.python.framework.importer import import_graph_def + +# Needed when you defined a new Op in C++. +from tensorflow.python.framework.ops import RegisterGradient +from tensorflow.python.framework.ops import NoGradient +from tensorflow.python.framework.ops import RegisterShape +from tensorflow.python.framework.tensor_shape import Dimension +from tensorflow.python.framework.tensor_shape import TensorShape + +from tensorflow.python.framework.types import * diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py new file mode 100644 index 0000000000..a726d880e7 --- /dev/null +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -0,0 +1,114 @@ +"""Updates generated docs from Python doc comments.""" + +import os.path + +import tensorflow.python.platform +import sys +import tensorflow as tf + +from tensorflow.python.framework import docs +from tensorflow.python.framework import framework_lib +from tensorflow.python.client import client_lib + + +tf.flags.DEFINE_string("out_dir", None, + "Directory to which docs should be written.") +tf.flags.DEFINE_boolean("print_hidden_regex", False, + "Dump a regular expression matching any hidden symbol") +FLAGS = tf.flags.FLAGS + + +def get_module_to_name(): + return {tf: 'tf', + tf.errors: 'tf.errors', + 
tf.image: 'tf.image', + tf.nn: 'tf.nn', + tf.train: 'tf.train', + tf.python_io: 'tf.python_io'} + +def all_libraries(module_to_name, members, documented): + # A list of (filename, docs.Library) pairs representing the individual files + # that we want to create. + def library(name, title, module=None, **args): + if module is None: + module = sys.modules["tensorflow.python.ops" + + ("" if name == "ops" else "." + name)] + return (name + ".md", docs.Library(title=title, + module_to_name=module_to_name, + members=members, + documented=documented, + module=module, + **args)) + return [ + # Splits of module 'tf'. + library("framework", "Building Graphs", framework_lib), + library("constant_op", "Constants, Sequences, and Random Values"), + library("state_ops", "Variables"), + library("array_ops", "Tensor Transformations", + exclude_symbols=["list_diff"]), + library("math_ops", "Math", + exclude_symbols=["sparse_matmul", "arg_min", "arg_max", + "lin_space", "sparse_segment_mean_grad"]), + library("control_flow_ops", "Control Flow"), + library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"]), + library("sparse_ops", "Sparse Tensors"), + library("io_ops", "Inputs and Readers", + exclude_symbols=["LookupTableBase", "HashTable", + "initialize_all_tables", + "string_to_hash_bucket"]), + library("python_io", "Data IO (Python functions)", tf.python_io), + library("nn", "Neural Network", tf.nn, + exclude_symbols=["deconv2d", "conv2d_backprop_input", + "conv2d_backprop_filter", "avg_pool_grad", + "max_pool_grad", "max_pool_grad_with_argmax", + "batch_norm_with_global_normalization_grad", + "lrn_grad", "relu6_grad", "softplus_grad", + "xw_plus_b", "relu_layer", "lrn", + "batch_norm_with_global_normalization", + "batch_norm_with_global_normalization_grad", + "all_candidate_sampler"]), + library('client', "Running Graphs", client_lib, + exclude_symbols=["InteractiveSession"]), + library("train", "Training", tf.train, + exclude_symbols=["Feature", "Features", 
"BytesList", "FloatList", + "Int64List", "Example", "InferenceExample", + "RankingExample", "SequenceExample"]), + ] + +_hidden_symbols = ["Event", "Summary", + "HistogramProto", "ConfigProto", "NodeDef", "GraphDef", + "GPUOptions", "SessionInterface", "BaseSession"] + +def main(unused_argv): + if not FLAGS.out_dir: + tf.logging.error("out_dir not specified") + return -1 + + # Document libraries + documented = set() + module_to_name = get_module_to_name() + members = docs.collect_members(module_to_name) + libraries = all_libraries(module_to_name, members, documented) + docs.write_libraries(FLAGS.out_dir, libraries) + + # Make it easy to search for hidden symbols + if FLAGS.print_hidden_regex: + hidden = set(_hidden_symbols) + for _, lib in libraries: + hidden.update(lib.exclude_symbols) + print r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden)) + + # Verify that all symbols are mentioned in some library doc. + catch_all = docs.Library(title="Catch All", module=None, + exclude_symbols=_hidden_symbols, + module_to_name=module_to_name, members=members, + documented=documented) + catch_all.assert_no_leftovers() + + # Generate index + with open(os.path.join(FLAGS.out_dir, "index.md"), "w") as f: + docs.Index(module_to_name, members, libraries).write_markdown_to_file(f) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/python/framework/gen_docs_test.sh b/tensorflow/python/framework/gen_docs_test.sh new file mode 100755 index 0000000000..fda214d93c --- /dev/null +++ b/tensorflow/python/framework/gen_docs_test.sh @@ -0,0 +1,4 @@ +#!/bin/bash -eux +DIR=$TEST_SRCDIR/tensorflow/python +$DIR/gen_docs_combined --out_dir $TEST_TMPDIR +echo "PASS" diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py new file mode 100644 index 0000000000..6ad2a1b009 --- /dev/null +++ b/tensorflow/python/framework/importer.py @@ -0,0 +1,303 @@ +"""A utility function for importing TensorFlow graphs.""" +import contextlib + 
import tensorflow.python.platform

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import types as types_lib


# TODO(josh11b): SWIG the code from node_def_util instead of duplicating
# the logic here.
def _GetNodeAttr(node_def, attr_name):
  """Returns the attr named `attr_name` from `node_def`, or raises ValueError."""
  if attr_name not in node_def.attr:
    raise ValueError('Expected one attr with name %r in %s.'
                     % (attr_name, str(node_def)))
  return node_def.attr[attr_name]


def _ArgToTypesNoRef(node_def, arg_def):
  """Returns the list of dtypes (enum values) for one OpDef arg, ignoring refs."""
  if arg_def.number_attr:
    # Repeated arg: the same dtype replicated `number_attr` times.
    repeats = _GetNodeAttr(node_def, arg_def.number_attr).i
    if arg_def.type_attr:
      dtype = _GetNodeAttr(node_def, arg_def.type_attr).type
    else:
      assert arg_def.type != types_pb2.DT_INVALID
      dtype = arg_def.type
    return [dtype] * repeats
  elif arg_def.type_attr:
    return [_GetNodeAttr(node_def, arg_def.type_attr).type]
  elif arg_def.type_list_attr:
    return _GetNodeAttr(node_def, arg_def.type_list_attr).list.type
  else:
    assert arg_def.type != types_pb2.DT_INVALID
    return [arg_def.type]


def _SingleArgToTypes(node_def, arg_def):
  """Like _ArgToTypesNoRef, but converts to ref dtypes when arg_def.is_ref."""
  types = _ArgToTypesNoRef(node_def, arg_def)
  if arg_def.is_ref:
    return [types_lib.as_dtype(dt).as_ref.as_datatype_enum for dt in types]
  return types


def _ArgsToTypes(node_def, arg_list):
  """Flattens the dtypes of all args in `arg_list` into a single list."""
  types = []
  for arg_def in arg_list:
    types.extend(_SingleArgToTypes(node_def, arg_def))
  return types


def _InputTypes(node_def, op_dict):
  """Returns the flattened input dtypes of `node_def` per its OpDef."""
  op_def = op_dict[node_def.op]
  return _ArgsToTypes(node_def, op_def.input_arg)


def _OutputTypes(node_def, op_dict):
  """Returns the flattened output dtypes of `node_def` per its OpDef."""
  op_def = op_dict[node_def.op]
  return _ArgsToTypes(node_def, op_def.output_arg)


def _IsControlInput(input_name):
  # Expected format: '^operation_name' (control input).
  return input_name.startswith('^')


def _ParseTensorName(tensor_name):
  """Parses a tensor name into an operation name and output index.

  This function will canonicalize tensor names as follows:

  * "foo:0"       -> ("foo", 0)
  * "foo:7"       -> ("foo", 7)
  * "foo"         -> ("foo", 0)
  * "foo:bar:baz" -> ValueError

  Args:
    tensor_name: The name of a tensor.

  Returns:
    A tuple containing the operation name, and the output index.

  Raises:
    ValueError: If `tensor_name' cannot be interpreted as the name of a tensor.
  """
  components = tensor_name.split(':')
  if len(components) == 2:
    # Expected format: 'operation_name:output_index'.
    try:
      output_index = int(components[1])
    except ValueError:
      raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
    return components[0], output_index
  elif len(components) == 1:
    # Expected format: 'operation_name' (implicit 0th output).
    return components[0], 0
  else:
    raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))


def _CanonicalInputName(input_name):
  """Returns `input_name` in the canonical 'op_name:index' form.

  Control inputs ('^op_name') are returned unchanged.
  """
  if _IsControlInput(input_name):
    return input_name
  input_op_name, output_index = _ParseTensorName(input_name)
  return '%s:%d' % (input_op_name, output_index)


def _InvalidNodeMessage(node, message):
  """Formats a graph_def validation error message for `node`."""
  return 'graph_def is invalid at node %r: %s.' % (node.name, message)


@contextlib.contextmanager
def _MaybeDevice(device):
  """Applies the given device only if device is not None or empty."""
  if device:
    with ops.device(device):
      yield
  else:
    yield


def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements'.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map' is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, basestring) for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, basestring) for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def'.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          # in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def'.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = source_op.values()[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(
                _InvalidNodeMessage(node, 'Input tensor %r %s'
                                    % (input_name, te.message)))

      # NOTE: the message name is spelled with a hyphen ('protected-access');
      # the underscore form was not a valid pylint pragma and had no effect.
      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
+ unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys) + if unused_input_keys: + raise ValueError( + 'Attempted to map inputs that were not found in graph_def: [%s]' + % ', '.join(unused_input_keys)) + + if return_elements is None: + return None + else: + ret = [] + for name in return_elements: + if ':' in name: + try: + operation_name, output_index = _ParseTensorName(name) + ret.append(name_to_op[operation_name].outputs[output_index]) + except (ValueError, KeyError, IndexError): + raise ValueError( + 'Requested return_element %r not found in graph_def.' % name) + else: + try: + ret.append(name_to_op[name]) + except KeyError: + raise ValueError( + 'Requested return_element %r not found in graph_def.' % name) + return ret diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py new file mode 100644 index 0000000000..470092313a --- /dev/null +++ b/tensorflow/python/framework/importer_test.py @@ -0,0 +1,546 @@ +"""Tests for tensorflow.python.framework.importer.""" + +import tensorflow.python.platform + +import tensorflow as tf + +from google.protobuf import text_format + +from tensorflow.core.framework import op_def_pb2 +from tensorflow.python.framework import device +from tensorflow.python.framework import op_def_registry + + +_op_list = op_def_pb2.OpList() +text_format.Merge(""" + op { + name: 'None' + } + op { + name: 'Oi' + output_arg { name: 'a' type: DT_INT32 } + } + op { + name: 'Or' + output_arg { name: 'a' type: DT_INT32 is_ref: true } + } + op { + name: 'Of' + output_arg { name: 'a' type: DT_FLOAT } + } + op { + name: 'Ii' + input_arg { name: 'a' type: DT_INT32 } + } + op { + name: 'If' + input_arg { name: 'a' type: DT_FLOAT } + } + op { + name: 'Oii' + output_arg { name: 'a' type: DT_INT32 } + output_arg { name: 'b' type: DT_INT32 } + } + op { + name: 'Oif' + output_arg { name: 'a' type: DT_INT32 } + output_arg { name: 'b' type: DT_FLOAT } + } + op { + name: 'Iii' + input_arg { name: 'a' 
type: DT_INT32 } + input_arg { name: 'b' type: DT_INT32 } + } + op { + name: 'Iff' + input_arg { name: 'a' type: DT_FLOAT } + input_arg { name: 'b' type: DT_FLOAT } + } + op { + name: 'Iif' + input_arg { name: 'a' type: DT_INT32 } + input_arg { name: 'b' type: DT_FLOAT } + } + op { + name: 'Iri' + input_arg { name: 'a' type: DT_INT32 is_ref: true } + input_arg { name: 'b' type: DT_INT32 } + } + op { + name: 'In' + input_arg { name: 'a' number_attr: 'N' type_attr: 'T' } + attr { name: 'N' type: 'int' minimum: 1 } + attr { name: 'T' type: 'type' } + } + op { + name: 'Otl' + output_arg { name: 'a' type_list_attr: 't' } + attr { name: 'T' type: 'list(type)' minimum: 1 } + } + op { + name: 'Unary' + input_arg { name: 'a' type_attr: 'T' } + output_arg { name: 'b' type_attr: 'T' } + attr { name: 'T' type: 'type' } + } +""", _op_list) +op_def_registry.register_op_list(_op_list) +# NOTE(mrry): Dummy shape registrations for ops used in the tests. +for op_def in _op_list.op: + tf.RegisterShape(op_def.name)(None) + +class ImportGraphDefTest(tf.test.TestCase): + + def _MakeGraphDef(self, text): + ret = tf.GraphDef() + text_format.Merge(text, ret) + return ret + + def testBasic(self): + with tf.Graph().as_default(): + a, b, c, d = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oif' } + node { name: 'B' op: 'Otl' + attr { key: 't' + value { list { type: DT_INT32 type: DT_FLOAT } } } } + node { name: 'C' op: 'In' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_INT32 } } + input: 'A:0' input: 'B:0' } + node { name: 'D' op: 'In' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_FLOAT } } + input: 'A:1' input: 'B:1' } + """), + return_elements=['A', 'B', 'C', 'D'], + name='import') + + # Assert that the import process creates distinct tensors. 
+ self.assertNotEqual(a.outputs[0].name, a.outputs[1].name) + self.assertNotEqual(b.outputs[0].name, b.outputs[1].name) + self.assertNotEqual(a.outputs[0].name, b.outputs[0].name) + self.assertNotEqual(a.outputs[0].name, b.outputs[1].name) + self.assertNotEqual(a.outputs[1].name, b.outputs[0].name) + self.assertNotEqual(a.outputs[1].name, b.outputs[1].name) + + # Assert that the ops are connected according to the GraphDef topology. + self.assertEqual(c.inputs[0], a.outputs[0]) + self.assertEqual(c.inputs[1], b.outputs[0]) + self.assertEqual(d.inputs[0], a.outputs[1]) + self.assertEqual(d.inputs[1], b.outputs[1]) + + # Check the types of the returned ops and tensors. + self.assertEqual(a.type, 'Oif') + self.assertEqual(b.type, 'Otl') + self.assertEqual(c.type, 'In') + self.assertEqual(d.type, 'In') + self.assertEqual(a.outputs[0].dtype, tf.int32) + self.assertEqual(a.outputs[1].dtype, tf.float32) + self.assertEqual(b.outputs[0].dtype, tf.int32) + self.assertEqual(b.outputs[1].dtype, tf.float32) + + # Check the names of the returned ops. 
+ self.assertEqual(a.name, 'import/A') + self.assertEqual(b.name, 'import/B') + self.assertEqual(c.name, 'import/C') + self.assertEqual(d.name, 'import/D') + + def testInputMap(self): + with tf.Graph().as_default(): + feed_a_0 = tf.constant(0, dtype=tf.int32) + feed_b_1 = tf.constant(1, dtype=tf.int32) + + a, b, c, d = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oii' } + node { name: 'B' op: 'Oii' } + node { name: 'C' op: 'In' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_INT32 } } + input: 'A:0' input: 'B:0' } + node { name: 'D' op: 'In' + attr { key: 'N' value { i: 2 } } + attr { key: 'T' value { type: DT_INT32 } } + input: 'A:1' input: 'B:1' } + """), + input_map={'A:0': feed_a_0, 'B:1': feed_b_1}, + return_elements=['A', 'B', 'C', 'D']) + + self.assertEqual(c.inputs[0], feed_a_0) + self.assertEqual(c.inputs[1], b.outputs[0]) + self.assertEqual(d.inputs[0], a.outputs[1]) + self.assertEqual(d.inputs[1], feed_b_1) + + def testImplicitZerothOutput(self): + with tf.Graph().as_default(): + a, b = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oii' } + node { name: 'B' op: 'Ii' input: 'A' } + """), + return_elements=['A', 'B']) + + self.assertEqual(b.inputs[0], a.outputs[0]) + + def testInputMapImplicitZerothOutput(self): + with tf.Graph().as_default(): + feed_a_0 = tf.constant(0, dtype=tf.int32) + b, = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oii' } + node { name: 'B' op: 'Ii' input: 'A:0' } + """), + input_map={'A': feed_a_0}, + return_elements=['B']) + + self.assertEqual(b.inputs[0], feed_a_0) + + def testWithControlDependency(self): + with tf.Graph().as_default(): + a, b = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'None' } + node { name: 'B' op: 'None' input: '^A' } + """), + return_elements=['A', 'B']) + + self.assertEqual(b.control_inputs, [a]) + + def testWithRefs(self): + with tf.Graph().as_default(): + a, b, c, d = 
tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Or' } + node { name: 'B' op: 'Oi' } + node { name: 'C' op: 'Iii' input: 'A:0' input: 'B:0' } + node { name: 'D' op: 'Iri' input: 'A:0' input: 'B:0' } + """), + return_elements=['A', 'B', 'C', 'D']) + + self.assertEqual(c.inputs[0], a.outputs[0]) + self.assertEqual(c.inputs[1], b.outputs[0]) + self.assertEqual(d.inputs[0], a.outputs[0]) + self.assertEqual(d.inputs[1], b.outputs[0]) + + self.assertEqual(a.outputs[0].dtype, tf.int32_ref) + self.assertEqual(c._input_dtypes, [tf.int32, tf.int32]) + self.assertEqual(c.outputs, []) + self.assertEqual(d._input_dtypes, + [tf.int32_ref, tf.int32]) + self.assertEqual(d.outputs, []) + + def testCyclic(self): + with tf.Graph().as_default(): + a, b = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Unary' + attr { key: 'T' value { type: DT_INT32 } } input: 'B:0' } + node { name: 'B' op: 'Unary' + attr { key: 'T' value { type: DT_INT32 } } input: 'A:0' } + """), + return_elements=['A', 'B']) + + self.assertEqual(a.inputs[0], b.outputs[0]) + self.assertEqual(b.inputs[0], a.outputs[0]) + + def testTypeMismatchInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + node { name: 'B' op: 'If' input: 'A:0' } + """)) + self.assertTrue( + 'Cannot convert a tensor of type int32 to an input of type float' in + str(e.exception)) + + def testInvalidSignatureTooManyInputsInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + node { name: 'B' op: 'None' input: 'A:0' } + """)) + self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in + str(e.exception)) + + def testInvalidSignatureNotEnoughInputsInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + 
tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + node { name: 'B' op: 'Iif' input: 'A:0' } + """)) + self.assertTrue('Input types mismatch (expected \'int32, float32\' but ' + 'got \'int32\')' in str(e.exception)) + + def testMissingInputOpInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'B' op: 'If' input: 'A:0' } + """)) + self.assertTrue('Input tensor %r not found' % (u'A:0',) in + str(e.exception)) + + def testMissingInputOpInGraphDefButAppearsInInputMap(self): + with tf.Graph().as_default(): + feed_a_0 = tf.constant(5.0) + b, = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'B' op: 'If' input: 'A:0' } + """), + input_map={'A:0': feed_a_0}, + return_elements=['B']) + self.assertEqual(b.inputs[0], feed_a_0) + + def testMissingInputTensorInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Of' } + node { name: 'B' op: 'If' input: 'A:1' } + """)) + self.assertTrue('Input tensor %r not found' % (u'A:1',) in + str(e.exception)) + + def testMissingControlInputInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'B' op: 'None' input: '^A' } + """)) + self.assertTrue('Control input %r not found' % (u'^A',) in + str(e.exception)) + + def testInvalidTensorNameOutputIndexInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'B' op: 'None' input: 'A:B' } + """)) + self.assertEqual( + 'Cannot convert %r to a tensor name.' 
% (u'A:B',), str(e.exception)) + + def testInvalidTensorNameInGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'B' op: 'None' input: 'A:B:0' } + """)) + self.assertEqual( + 'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception)) + + def testMissingReturnOperation(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'None' } + """), + return_elements=['B']) + self.assertTrue('return_element %r not found in graph_def.' % ('B') in + str(e.exception)) + + def testMissingReturnTensor(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + """), + return_elements=['A:1']) + self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in + str(e.exception)) + + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + """), + return_elements=['B:0']) + self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in + str(e.exception)) + + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + """), + return_elements=['A:B:0']) + self.assertTrue('return_element %r not found in graph_def.' 
% ('A:B:0') in + str(e.exception)) + + def testMissingInputMap(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'None' } + """), + input_map={'B:0': tf.constant(5.0)}) + self.assertTrue('not found in graph_def: [B:0]' in str(e.exception)) + + def testInputMapTypeMismatch(self): + with tf.Graph().as_default(): + with self.assertRaises(ValueError) as e: + tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'Oi' } + node { name: 'B' op: 'Ii' input: 'A:0' } + """), + input_map={'A:0': tf.constant(5.0)}) + self.assertTrue( + 'Cannot convert a tensor of type float32 to an input of type int32.' + in str(e.exception)) + + def testNoReturns(self): + with tf.Graph().as_default() as g: + ret = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'None' } + """)) + self.assertEqual(ret, None) + + a = g.get_operation_by_name('import/A') + self.assertEqual(a.type, 'None') + + def testOverrideNamePrefix(self): + with tf.Graph().as_default(): + a, = tf.import_graph_def( + self._MakeGraphDef(""" + node { name: 'A' op: 'None' } + """), + return_elements=['A'], name='imported_graph') + self.assertEqual(a.name, 'imported_graph/A') + + def testEmptyGraph(self): + with tf.Graph().as_default() as g: + init_version = g.version + tf.import_graph_def(self._MakeGraphDef('')) + self.assertEqual(init_version, g.version) + + def testInvalidInputForGraphDef(self): + with tf.Graph().as_default(): + with self.assertRaises(TypeError) as e: + tf.import_graph_def('') + self.assertEqual( + 'graph_def must be a GraphDef proto.', str(e.exception)) + + def testInvalidInputForInputMap(self): + with tf.Graph().as_default(): + with self.assertRaises(TypeError) as e: + tf.import_graph_def(self._MakeGraphDef(''), + input_map=[tf.constant(5.0)]) + self.assertEqual('input_map must be a dictionary mapping strings to ' + 'Tensor objects.', str(e.exception)) + + def 
testInvalidInputForReturnOperations(self): + with tf.Graph().as_default(): + with self.assertRaises(TypeError) as e: + tf.import_graph_def(self._MakeGraphDef(''), return_elements=[7]) + self.assertEqual( + 'return_elements must be a list of strings.', str(e.exception)) + + def testWithExtensionAndAttr(self): + with tf.Graph().as_default() as g: + c = tf.constant(5.0, dtype=tf.float32, name='c') + tf.pack([c, c], name='pack') + gdef = g.as_graph_def() + + with self.test_session(): + pack, = tf.import_graph_def(gdef, return_elements=['pack']) + self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0]) + + def testWithDevice(self): + with tf.Graph().as_default() as g: + # No device. + a = tf.constant(3.0, name='a') + + with tf.device('/cpu:0'): + b = tf.constant(4.0, name='b') + with tf.device('/job:worker'): + c = tf.constant(5.0, name='c') + + gdef = g.as_graph_def() + + with tf.Graph().as_default(): + a2, b2, c2 = tf.import_graph_def( + gdef, return_elements=['a', 'b', 'c']) + self.assertEqual(a.device, a2.device) + self.assertEqual(b.device, b2.device) + self.assertEqual(c.device, c2.device) + + with tf.Graph().as_default(): + with tf.device(device.merge_device('/task:0')): + a3, b3, c3 = tf.import_graph_def( + gdef, return_elements=['a', 'b', 'c']) + self.assertEqual('/task:0', a3.device) + self.assertEqual('/task:0/device:CPU:0', b3.device) # canonicalized. + self.assertEqual(c.device + '/task:0', c3.device) + + with tf.Graph().as_default(): + with tf.device(device.merge_device('/job:ps')): + a4, b4, c4 = tf.import_graph_def( + gdef, return_elements=['a', 'b', 'c']) + self.assertEqual('/job:ps', a4.device) + self.assertEqual('/job:ps/device:CPU:0', b4.device) # canonicalized. + self.assertEqual(c.device, c4.device) # worker overrides ps. 
+ + with tf.Graph().as_default(): + with tf.device(device.merge_device('/gpu:0')): + a5, b5, c5 = tf.import_graph_def( + gdef, return_elements=['a', 'b', 'c']) + self.assertEqual('/device:GPU:0', a5.device) + self.assertEqual('/device:CPU:0', b5.device) # cpu overrides gpu. + self.assertEqual(c.device + '/device:GPU:0', c5.device) + + def testGradient(self): + with tf.Graph().as_default() as g: + inputs = tf.placeholder(tf.float32, shape=[None, 100], name="input") + weights = tf.placeholder(tf.float32, shape=[100, 10], name="weights") + biases = tf.placeholder(tf.float32, shape=[10], name="biases") + activations = tf.nn.relu(tf.matmul(inputs, weights) + biases, + name="activations") + loss = tf.reduce_mean(activations, name="loss") + gdef = g.as_graph_def() + + with tf.Graph().as_default() as g: + input_placeholder = tf.placeholder(tf.float32, shape=[32, 100]) + weights_var = tf.Variable(tf.truncated_normal([100, 10]), name="weights") + biases_var = tf.Variable(tf.zeros(10), name="biases") + activations, loss = tf.import_graph_def( + gdef, + input_map={"input:0": input_placeholder, + "weights:0": weights_var, + "biases:0": biases_var}, + return_elements=["activations:0", "loss:0"]) + self.assertEqual([32, 10], activations.get_shape()) + self.assertEqual([], loss.get_shape()) + weights_grad, biases_grad = tf.gradients(loss, [weights_var, biases_var]) + self.assertEqual([100, 10], weights_grad.get_shape()) + self.assertEqual([10], biases_grad.get_shape()) + + def testLargeGraph(self): + with self.test_session(): + # The default message byte limit is 64M. Ours is 2G with a warning at 512. + # Adding a 150M entries float32 tensor should blow through the warning, + # but not the hard limit. 
+ input_shape = [150, 1024, 1024] + tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32) + t = tf.constant(tensor_input, shape=input_shape) + g = tf.identity(t) + g.eval() + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/framework/op_def_registry.py b/tensorflow/python/framework/op_def_registry.py new file mode 100644 index 0000000000..2ec8c94a10 --- /dev/null +++ b/tensorflow/python/framework/op_def_registry.py @@ -0,0 +1,23 @@ +"""Global registry for OpDefs.""" + +from tensorflow.core.framework import op_def_pb2 + + +_registered_ops = {} + + +def register_op_list(op_list): + """Register all the ops in an op_def_pb2.OpList.""" + if not isinstance(op_list, op_def_pb2.OpList): + raise TypeError("%s is %s, not an op_def_pb2.OpList" % + (op_list, type(op_list))) + for op_def in op_list.op: + if op_def.name in _registered_ops: + assert _registered_ops[op_def.name] == op_def + else: + _registered_ops[op_def.name] = op_def + + +def get_registered_ops(): + """Returns a dictionary mapping names to OpDefs.""" + return _registered_ops diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py new file mode 100644 index 0000000000..0b0442cea1 --- /dev/null +++ b/tensorflow/python/framework/ops.py @@ -0,0 +1,2985 @@ +"""Classes and functions used to construct graphs.""" +# pylint: disable=g-bad-name +import collections +import contextlib +import copy +import linecache +import re +import sys +import threading +import weakref + +import tensorflow.python.platform + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import registry +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import types + + +def _convert_stack(stack): + """Converts a stack extracted using _extract_stack() to a traceback stack. 
+ + Args: + stack: A list of n 4-tuples, (filename, lineno, name, frame_globals). + + Returns: + A list of n 4-tuples (filename, lineno, name, code), where the code tuple + element is calculated from the corresponding elements of the input tuple. + """ + ret = [] + for filename, lineno, name, frame_globals in stack: + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, frame_globals) + if line: + line = line.strip() + else: + line = None + ret.append((filename, lineno, name, line)) + return ret + + +# pylint: disable=line-too-long +def _extract_stack(): + """A lightweight re-implementation of traceback.extract_stack. + + NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for + each stack frame using linecache, which results in an abundance of stat() + calls. This implementation does not retrieve the code, and any consumer + should apply _convert_stack to the result to obtain a traceback that can + be formatted etc. using traceback methods. + + Returns: + A list of 4-tuples (filename, lineno, name, frame_globals) corresponding to + the call stack of the current thread. + """ + # pylint: enable=line-too-long + try: + raise ZeroDivisionError + except ZeroDivisionError: + f = sys.exc_info()[2].tb_frame.f_back + ret = [] + while f is not None: + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + frame_globals = f.f_globals + ret.append((filename, lineno, name, frame_globals)) + f = f.f_back + ret.reverse() + return ret + + +class Tensor(object): + """Represents a value produced by an `Operation`. + + A `Tensor` is a symbolic handle to one of the outputs of an + `Operation`. It does not hold the values of that operation's output, + but instead provides a means of computing those values in a + TensorFlow [`Session`](client.md#Session). + + This class has two primary purposes: + + 1. A `Tensor` can be passed as an input to another `Operation`. 
class Tensor(object):
  """Represents a value produced by an `Operation`.

  A `Tensor` is a symbolic handle to one of the outputs of an
  `Operation`. It does not hold the values of that operation's output,
  but instead provides a means of computing those values in a
  TensorFlow [`Session`](client.md#Session).

  This class has two primary purposes:

  1. A `Tensor` can be passed as an input to another `Operation`.
     This builds a dataflow connection between operations, which
     enables TensorFlow to execute an entire `Graph` that represents a
     large, multi-step computation.

  2. After the graph has been launched in a session, the value of the
     `Tensor` can be computed by passing it to
     [`Session.run()`](client.md#Session.run).
     `t.eval()` is a shortcut for calling
     `tf.get_default_session().run(t)`.

  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
  objects, whereas `result` is a numpy array that stores a concrete
  value:

  ```python
  # Build a dataflow graph.
  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
  e = tf.matmul(c, d)

  # Construct a `Session` to execute the graph.
  sess = tf.Session()

  # Execute the graph and store the value that `e` represents in `result`.
  result = sess.run(e)
  ```

  @@dtype
  @@name
  @@value_index
  @@graph
  @@op
  @@consumers

  @@eval

  @@get_shape
  @@set_shape

  """

  # List of Python operators that we allow to override.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__", "__radd__",
      "__sub__", "__rsub__",
      "__mul__", "__rmul__",
      "__div__", "__rdiv__",
      "__truediv__", "__rtruediv__",
      "__mod__", "__rmod__",
      "__lt__", "__le__",
      "__gt__", "__ge__",
      "__and__", "__rand__",
      "__or__", "__ror__",
      "__xor__", "__rxor__",
      "__getitem__",
      # Unary.
      "__invert__",
      "__neg__", "__abs__"}

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `types.DType`. Type of data stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op = op
    self._value_index = value_index
    self._dtype = types.as_dtype(dtype)
    # Shape starts fully unknown; refined later via set_shape()/inference.
    self._shape = tensor_shape.unknown_shape()
    # List of operations that use this Tensor as input. We maintain this list
    # to easily navigate a computation graph.
    self._consumers = []

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor, "<op_name>:<value_index>"."""
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    return "%s:%d" % (self._op.name, self._value_index)

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  def _shape_as_list(self):
    # None when the rank itself is unknown; unknown dims appear as None values.
    if self._shape.ndims is None:
      return None
    return [dim.value for dim in self._shape.dims]

  def get_shape(self):
    """Returns the `TensorShape` that represents the shape of this tensor.

    The shape is computed using the shape inference functions registered for
    each `Operation` type with `tf.RegisterShape`; see
    [`TensorShape`](framework.md#TensorShape) for details. The inferred shape
    provides shape information without launching the graph in a session,
    which is useful for debugging and early error messages. For example:

    ```python
    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    print c.get_shape()
    ==> TensorShape([Dimension(2), Dimension(3)])

    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])

    print d.get_shape()
    ==> TensorShape([Dimension(4), Dimension(2)])

    # Raises a ValueError, because `c` and `d` do not have compatible
    # inner dimensions.
    e = tf.matmul(c, d)

    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)

    print f.get_shape()
    ==> TensorShape([Dimension(3), Dimension(4)])
    ```

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `Tensor.set_shape()` can be used to augment the
    inferred shape.

    Returns:
      A `TensorShape` representing the shape of this tensor.
    """
    return self._shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    This method can be called multiple times; each call merges the given
    `shape` with the current shape. Use it to supply shape knowledge that
    cannot be inferred from the graph alone, e.g. for images:

    ```python
    _, image_data = tf.TFRecordReader(...).read(...)
    image = tf.image.decode_png(image_data, channels=3)

    # The height and width dimensions of `image` are data dependent, and
    # cannot be computed without executing the op.
    print image.get_shape()
    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])

    # We know that each image in this dataset is 28 x 28 pixels.
    image.set_shape([28, 28, 3])
    print image.get_shape()
    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    ```

    Args:
      shape: A `TensorShape` representing the shape of this tensor.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    self._shape = self._shape.merge_with(shape)

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

    Returns:
      A list of `Operation`s.
    """
    return self._consumers

  def _add_consumer(self, consumer):
    """Add a consumer to this tensor.

    Args:
      consumer: an Operation.

    Raises:
      TypeError: if the consumer is not an Operation.
    """
    if not isinstance(consumer, Operation):
      raise TypeError("Consumer must be an Operation: %s" % consumer)
    self._consumers.append(consumer)

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    # Output 0 is conventionally written without an explicit index.
    if self._value_index == 0:
      return self._op.name
    return "%s:%d" % (self._op.name, self._value_index)

  def __str__(self):
    shape_part = (", shape=%s" % self.get_shape()
                  if self.get_shape().ndims is not None else "")
    dtype_part = (", dtype=%s" % self._dtype.name) if self._dtype else ""
    device_part = (", device=%s" % self.device) if self.device else ""
    return "Tensor(\"%s\"%s%s%s)" % (
        self.name, shape_part, dtype_part, device_part)

  def __hash__(self):
    # Necessary to support Python's collection membership operators
    return id(self)

  def __eq__(self, other):
    # Necessary to support Python's collection membership operators
    return id(self) == id(other)

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  @staticmethod
  def _override_operator(operator, func):
    """Overrides (string) operator on Tensors to call func.

    Args:
      operator: the string name of the operator to override.
      func: the function that replaces the overriden operator.

    Raises:
      ValueError: If operator has already been overwritten,
        or if operator is not allowed to be overwritten.
    """
    existing = getattr(Tensor, operator, None)
    if existing is not None:
      # check to see if this is a default method-wrapper which will be true
      # for the comparison operators.
      if not isinstance(existing, type(all.__call__)):
        raise ValueError("operator %s cannot be overwritten again." % operator)
    if operator not in Tensor.OVERLOADABLE_OPERATORS:
      raise ValueError("Overriding %s is disallowed" % operator)
    setattr(Tensor, operator, func)

  def __iter__(self):
    """Dummy method to prevent iteration. Do not call.

    NOTE(mrry): If we register __getitem__ as an overloaded operator,
    Python will valiantly attempt to iterate over the Tensor from 0 to
    infinity. Declaring this method prevents this unintended
    behavior.

    Raises:
      TypeError: when invoked.
    """
    raise TypeError("'Tensor' object is not iterable")

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See [`Session.run()`](client.md#Session.run) for a description of
        the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.

    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)
def _TensorTensorConversionFunction(t, dtype=None, name=None):
  """Identity conversion: a Tensor converts to itself (dtype permitting)."""
  _ = name  # Unused: no new Tensor is created.
  if dtype and not dtype.is_compatible_with(t.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
        % (dtype.name, t.dtype.name, str(t)))
  return t


# Maps priority -> list of (base_type, conversion_func) pairs; lower
# priorities are consulted first by convert_to_tensor().
_tensor_conversion_func_registry = {
    0: [(Tensor, _TensorTensorConversionFunction)]}


def convert_to_tensor(value, dtype=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np
  array = np.random.rand((32, 100, 100))

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value`.
    RuntimeError: If a registered conversion function returns an invalid value.

  """
  error_prefix = "" if name is None else "%s: " % name
  if dtype is not None:
    dtype = types.as_dtype(dtype)
  # Try converters in ascending priority order; within a priority level,
  # registration order is preserved.
  for _, converters in sorted(_tensor_conversion_func_registry.items()):
    for base_type, conversion_func in converters:
      if not isinstance(value, base_type):
        continue
      ret = conversion_func(value, dtype=dtype, name=name)
      # Sanity-check the converter's output before handing it to the caller.
      if not isinstance(ret, Tensor):
        raise RuntimeError(
            "%sConversion function %r for type %s returned non-Tensor: %r"
            % (error_prefix, conversion_func, base_type, ret))
      if dtype and not dtype.is_compatible_with(ret.dtype):
        raise RuntimeError(
            "%sConversion function %r for type %s returned incompatible "
            "dtype: requested = %s, actual = %s"
            % (error_prefix, conversion_func, base_type,
               dtype.name, ret.dtype.name))
      return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered."
                  % (error_prefix, value, type(value)))
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or an `IndexedSlices` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, IndexedSlices):
    # BUG FIX: this previously called the nonexistent `types.AsDType`;
    # every other call site in this module uses `types.as_dtype`, and the
    # old name raised AttributeError whenever a dtype was supplied together
    # with an IndexedSlices value.
    if dtype and not types.as_dtype(dtype).is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
          % (types.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return convert_to_tensor(value, dtype, name)
+ """ + if isinstance(value, IndexedSlices): + if dtype and not types.AsDType(dtype).is_compatible_with(value.dtype): + raise ValueError( + "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" + % (types.AsDType(dtype).name, value.dtype.name, str(value))) + return value + else: + return convert_to_tensor(value, dtype, name) + + +def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None): + """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. + + Args: + values: A list of `None`, `IndexedSlices`, or objects that can be consumed + by `convert_to_tensor()`. + dtype: (Optional.) The required `DType` of the returned `Tensor` + `IndexedSlices`. + + name: (Optional.) A name prefix to used when a new `Tensor` is + created, in which case element `i` will be given the name `name + + '_' + i`. + + Returns: + A list of `Tensor` and/or `IndexedSlices` objects. + + Raises: + TypeError: If no conversion function is registered for an element in + `values`. + RuntimeError: If a registered conversion function returns an invalid + value. + """ + if not isinstance(values, collections.Sequence): + raise TypeError("values must be a list.") + ret = [] + for i, value in enumerate(values): + if value is None: + ret.append(value) + else: + n = None if name is None else "%s_%d" % (name, i) + ret.append( + convert_to_tensor_or_indexed_slices(value, dtype=dtype, name=n)) + return ret + + +def register_tensor_conversion_function(base_type, conversion_func, + priority=100): + """Registers a function for converting objects of base_type to Tensor. + + The conversion function must have the following signature: + + def conversion_func(value, dtype=None, name=None): + # ... + + It must return a Tensor with the given dtype if specified. If the + conversion function creates a new Tensor, it should use the given + name if specified. All exceptions will be propagated to the caller. 
class IndexedSlices(object):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. [`tf.gather`](array_ops.md#gather)).

  Contrast this representation with
  [`SparseTensor`](sparse_ops.md#SparseTensor),
  which uses multi-dimensional indices and scalar values.

  @@__init__

  @@values
  @@indices
  @@dense_shape

  @@name
  @@dtype
  @@device
  @@op
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`."""
    # Plain attribute capture; no validation or conversion is performed here.
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  # The name/device/op/dtype accessors below all delegate to `values`.

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s)" % (
        self._indices, self._values)
+ """ + for item in items: + if not expected_graph: + expected_graph = item.graph + elif expected_graph != item.graph: + raise ValueError("Items must be from the same graph.") + return items + + +class SparseTensor(object): + """Represents a sparse tensor. + + Tensorflow represents a sparse tensor as three separate dense tensors: + `indices`, `values`, and `dense_shape`. In Python, the three tensors are + collected into a `SparseTensor` class for ease of use. If you have separate + `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor` + object before passing to the Ops below. + + Concretely, the sparse tensor `SparseTensor(values, indices, dense_shape)` is + + * `indices`: A 2-D int64 tensor of shape `[N, ndims]`. + * `values`: A 1-D tensor of any type and shape `[N]`. + * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`. + + where `N` and `ndims` are the number of values, and number of dimensions in + the `SparseTensor` respectively. + + The corresponding dense tensor satisfies + + ```python + dense.shape = dense_shape + dense[tuple(indices[i])] = values[i] + ``` + + By convention, `indices` should be sorted in row-major order (or equivalently + lexigraphic order on the tuples `indices[i]`). This is not enforced when + `SparseTensor` objects are constructed, but most Ops assume correct ordering. + If the ordering is wrong, it can be fixed by calling `sparse_reorder` on the + misordered `SparseTensor`. + + Example: The sparse tensor + + ```python + SparseTensor(values=[1, 2], indices=[[0, 0], [1, 2]], shape=[3, 4]) + ``` + + represents the dense tensor + + ```python + [[1, 0, 0, 0] + [0, 0, 2, 0] + [0, 0, 0, 0]] + ``` + + @@__init__ + @@indices + @@values + @@dtype + @@shape + @@graph + """ + + def __init__(self, indices, values, shape): + """Creates a `SparseTensor`. + + Args: + indices: A 2-D int64 tensor of shape `[N, ndims]`. + values: A 1-D tensor of any type and shape `[N]`. + dense_shape: A 1-D int64 tensor of shape `[ndims]`. 
+ + Returns: + A `SparseTensor` + """ + with op_scope([indices, values, shape], None, "SparseTensor"): + indices = convert_to_tensor(indices, name="indices") + values = convert_to_tensor(values, name="values") + shape = convert_to_tensor(shape, name="shape") + self._indices = indices + self._values = values + self._shape = shape + + indices_shape = indices.get_shape().with_rank(2) + values_shape = values.get_shape().with_rank(1) + shape_shape = shape.get_shape().with_rank(1) + + # Assert number of rows in indices match the number of elements in values. + indices_shape[0].merge_with(values_shape[0]) + # Assert number of columns in indices matches the number of elements in + # shape. + indices_shape[1].merge_with(shape_shape[0]) + + @property + def indices(self): + """The indices of non-zero values in the represented dense tensor. + + Returns: + A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the + number of non-zero values in the tensor, and `ndims` is the rank. + """ + return self._indices + + @property + def values(self): + """The non-zero values in the represented dense tensor. + + Returns: + A 1-D Tensor of any data type. 
+ """ + return self._values + + @property + def dtype(self): + """The `DType` of elements in this tensor.""" + return self._values.dtype + + @property + def shape(self): + """A 1-D Tensor of int64 representing the shape of the dense tensor.""" + return self._shape + + @property + def graph(self): + """The `Graph` that contains the index, value, and shape tensors.""" + return self._indices.graph + + def __str__(self): + return "SparseTensor(indices=%s, values=%s, shape=%s)" % ( + self._indices, self._values, self._shape) + + +SparseTensorValue = collections.namedtuple("SparseTensorValue", + ["indices", "values", "shape"]) + + +def _device_string(dev_spec): + if isinstance(dev_spec, pydev.Device): + return dev_spec.to_string() + else: + return dev_spec + + +def _NodeDef(op_type, name, device=None, attrs=None): + """Create a NodeDef proto. + + Args: + op_type: Value for the "op" attribute of the NodeDef proto. + name: Value for the "name" attribute of the NodeDef proto. + device: string, device, or function from NodeDef to string. + Value for the "device" attribute of the NodeDef proto. + attrs: optional list for the "attr" attribute of the NodeDef proto. + + Returns: + A graph_pb2.NodeDef protocol buffer. + """ + node_def = graph_pb2.NodeDef() + node_def.op = str(op_type) + node_def.name = str(name) + if attrs is not None: + for k, v in attrs.iteritems(): + node_def.attr[k].CopyFrom(v) + if device is not None: + if callable(device): + node_def.device = device(node_def) + else: + node_def.device = _device_string(device) + return node_def + + +# Copied from core/framework/node_def_util.cc +# TODO(mrry,josh11b): Consolidate this validation in C++ code. +_VALID_OP_NAME_REGEX = re.compile("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*") + + +class Operation(object): + """Represents a graph node that performs computation on tensors. 
+ + An `Operation` is a node in a TensorFlow `Graph` that takes zero or + more `Tensor` objects as input, and produces zero or more `Tensor` + objects as output. Objects of type `Operation` are created by + calling a Python op constructor (such as [`tf.matmul()`](math_ops.md#matmul)) + or [`Graph.create_op()`](framework.md#Graph.create_op). + + For example `c = tf.matmul(a, b)` creates an `Operation` of type + "MatMul" that takes tensors `a` and `b` as input, and produces `c` + as output. + + After the graph has been launched in a session, an `Operation` can + be executed by passing it to [`Session.run()`](client.md#Session.run). + `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. + + @@name + @@type + @@inputs + @@control_inputs + @@outputs + @@device + @@graph + + @@run + + @@get_attr + @@traceback + """ + + def __init__(self, node_def, g, inputs=None, output_types=None, + control_inputs=None, input_types=None, original_op=None, + op_def=None): + """Creates an `Operation`. + + NOTE: This constructor validates the name of the Operation (passed + as "node_def.name"). Valid Operation names match the following + regular expression: + + [A-Za-z0-9.][A-Za-z0-9_.\\-/]* + + Args: + node_def: graph_pb2.NodeDef. NodeDef for the Operation. + Used for attributes of graph_pb2.NodeDef, typically "name", + "op", and "device". The "input" attribute is irrelevant here + as it will be computed when generating the model. + g: Graph. The parent graph. + inputs: list of Tensor objects. The inputs to this Operation. + output_types: list of types_pb2.DataType. List of the types of the + Tensors computed by this operation. The length of this list indicates + the number of output endpoints of the Operation. + control_inputs: list of operations or tensors from which to have a + control dependency. + input_types: List of types_pb2.DataType representing the + types of the Tensors accepted by the Operation. By default + uses [x.dtype.base_dtype for x in inputs]. 
Operations that expect + reference-typed inputs must specify these explicitly. + original_op: Optional. Used to associate the new Operation with an + existing Operation (for example, a replica with the op that was + replicated). + op_def: Optional. The op_def_pb2.OpDef proto that describes the + op type that this Operation represents. + + Raises: + TypeError: if control inputs are not Operations or Tensors, + or if node_def is not a NodeDef, + or if g is not a Graph, + or if inputs are not Tensors, + or if inputs and input_types are incompatible. + ValueError: if the node_def name is not valid. + """ + if not isinstance(node_def, graph_pb2.NodeDef): + raise TypeError("node_def needs to be a NodeDef: %s" % node_def) + if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: + raise ValueError( + "Cannot create an Operation with a NodeDef larger than 2GB.") + if not _VALID_OP_NAME_REGEX.match(node_def.name): + raise ValueError("'%s' is not a valid node name" % node_def.name) + if not isinstance(g, Graph): + raise TypeError("g needs to be a Graph: %s" % g) + self._node_def = copy.deepcopy(node_def) + self._graph = g + if inputs is None: + inputs = [] + self._inputs = inputs + for a in self._inputs: + if not isinstance(a, Tensor): + raise TypeError("input needs to be a Tensor: %s" % a) + # Mark that we consume the inputs. + a._add_consumer(self) # pylint: disable=protected-access + if output_types is None: + output_types = [] + self._output_types = output_types + self._outputs = [Tensor(self, i, output_types[i]) + for i in xrange(len(output_types))] + if input_types is None: + input_types = [i.dtype.base_dtype for i in self._inputs] + else: + if not all(x.is_compatible_with(i.dtype) + for i, x in zip(self._inputs, input_types)): + raise TypeError("Inputs are not compatible with input types") + self._input_types = input_types + + # Build the list of control inputs. 
+ self._control_inputs = [] + if control_inputs: + for c in control_inputs: + c_op = None + if isinstance(c, Operation): + c_op = c + elif isinstance(c, (Tensor, IndexedSlices)): + c_op = c.op + else: + raise TypeError("Control input must be an Operation, " + "a Tensor, or IndexedSlices: %s" % c) + self._control_inputs.append(c_op) + + self._original_op = original_op + self._op_def = op_def + self._traceback = _extract_stack() + # Add this op to the current control flow context: + self._control_flow_context = g._get_control_flow_context() + if g._get_control_flow_context() is not None: + g._get_control_flow_context().AddOp(self) + # NOTE(keveman): Control flow context's AddOp could be creating new ops and + # setting op.inputs[index] = new_op. Thus the new ops' id could be larger + # than this op's id even though this op depend on them. Therefore, delaying + # assigning id to this op until all ops this could be dependent on are + # created. + self._id_value = self._graph._next_id() # pylint: disable=protected-access + self._recompute_node_def() + + def values(self): + """DEPRECATED: Use outputs.""" + return tuple(self.outputs) + + def _get_control_flow_context(self): + """Returns the current control flow context. + + Returns: + A context object. + """ + return self._control_flow_context + + @property + def name(self): + """The full name of this operation.""" + return self._node_def.name + + @property + def _id(self): + """The unique integer id of this operation.""" + return self._id_value + + @property + def device(self): + """The name of the device to which this op has been assigned, if any. + + Returns: + The string name of the device to which this op has been + assigned, or None if it has not been assigned to a device. + """ + dev = self._node_def.device + return None if not dev else dev + + def _set_device(self, device): + """Set the device of this operation. + + Args: + device: string or device.. The device to set. 
+ """ + self._node_def.device = _device_string(device) + + def _add_input(self, tensor, dtype=None): + """Add a new input to this operation. + + Args: + tensor: the Tensor to add as an input. + dtype: types.DType: type of the input; defaults to + the tensor's dtype. + + Raises: + TypeError: if tensor is not a Tensor, + or if input tensor type is not convertible to dtype. + ValueError: if the Tensor is from a different graph. + """ + if not isinstance(tensor, Tensor): + raise TypeError("tensor must be a Tensor: %s" % tensor) + assert_same_graph([self, tensor]) + if dtype is None: + dtype = tensor.dtype + else: + dtype = types.as_dtype(dtype) + if not dtype.is_compatible_with(tensor.dtype): + raise TypeError( + "Cannot convert a tensor of type %s to an input of type %s" + % (tensor.dtype.name, dtype.name)) + self._inputs.append(tensor) + self._input_types.append(dtype) + tensor._add_consumer(self) # pylint: disable=protected-access + self._recompute_node_def() + + def _update_input(self, index, tensor, dtype=None): + """Update the input to this operation at the given index. + + NOTE: This is for TF internal use only. Please don't use it. + + Args: + index: the index of the input to update. + tensor: the Tensor to be used as the input at the given index. + dtype: types.DType: type of the input; defaults to + the tensor's dtype. + + Raises: + TypeError: if tensor is not a Tensor, + or if input tensor type is not convertible to dtype. + ValueError: if the Tensor is from a different graph. 
+ """ + if not isinstance(tensor, Tensor): + raise TypeError("tensor must be a Tensor: %s" % tensor) + assert_same_graph([self, tensor]) + if dtype is None: + dtype = tensor.dtype + else: + dtype = types.as_dtype(dtype) + if not dtype.is_compatible_with(tensor.dtype): + raise TypeError( + "Cannot convert a tensor of type %s to an input of type %s" + % (tensor.dtype.name, dtype.name)) + + self._inputs[index].consumers().remove(self) + self._inputs[index] = tensor + self._input_types[index] = dtype + tensor._add_consumer(self) # pylint: disable=protected-access + self._recompute_node_def() + + def _add_control_input(self, op): + """Add a new control input to this operation. + + Args: + op: the Operation to add as control input. + + Raises: + TypeError: if op is not an Operation. + ValueError: if op is from a different graph. + """ + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + assert_same_graph([self, op]) + self._control_inputs.append(op) + self._recompute_node_def() + + # Methods below are used when building the NodeDef and Graph proto. 
+ def _recompute_node_def(self): + del self._node_def.input[:] + self._node_def.input.extend([t._as_node_def_input() for t in self._inputs]) + if self._control_inputs: + self._node_def.input.extend(["^%s" % op.name for op in + self._control_inputs]) + + def __str__(self): + return str(self._node_def) + + @property + def outputs(self): + """The list of `Tensor` objects representing the outputs of this op.""" + return self._outputs + +# pylint: disable=protected-access + class _InputList(object): + """Immutable input list wrapper.""" + + def __init__(self, op): + self._op = op + + def __iter__(self): + return iter(self._op._inputs) + + def __len__(self): + return len(self._op._inputs) + + def __bool__(self): + return bool(self._op._inputs) + + def __getitem__(self, i): + return self._op._inputs[i] +# pylint: enable=protected-access + + @property + def inputs(self): + """The list of `Tensor` objects representing the data inputs of this op.""" + return Operation._InputList(self) + + @property + def _input_dtypes(self): + return self._input_types + + @property + def control_inputs(self): + """The `Operation` objects on which this op has a control dependency. + + Before this op is executed, TensorFlow will ensure that the + operations in `self.control_inputs` have finished executing. This + mechanism can be used to run ops sequentially for performance + reasons, or to ensure that the side effects of an op are observed + in the correct order. + + Returns: + A list of `Operation` objects. + + """ + return self._control_inputs + + @property + def type(self): + """The type of the op (e.g. `"MatMul"`).""" + return self._node_def.op + + @property + def graph(self): + """The `Graph` that contains this operation.""" + return self._graph + + @property + def node_def(self): + """Returns a serialized `NodeDef` representation of this operation. 
+ + Returns: + A + [`NodeDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto) + protocol buffer. + """ + return self._node_def + + @property + def op_def(self): + """Returns the `OpDef` proto that represents the type of this op. + + Returns: + An + [`OpDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def.proto) + protocol buffer. + """ + return self._op_def + + @property + def traceback(self): + """Returns the call stack from when this operation was constructed.""" + return _convert_stack(self._traceback) + + def get_attr(self, name): + """Returns the value of the attr of this op with the given `name`. + + Args: + name: The name of the attr to fetch. + + Returns: + The value of the attr, as a Python object. + + Raises: + ValueError: If this op does not have an attr with the given `name`. + """ + fields = ["s", "i", "f", "b", "type", "shape", "tensor"] + if name not in self._node_def.attr: + raise ValueError("No attr named '" + name + "' in " + + str(self._node_def)) + x = self._node_def.attr[name] + # Treat an empty oneof value as an empty list. + if not x.WhichOneof("value"): + return [] + if x.HasField("list"): + for f in fields: + if getattr(x.list, f): + return list(getattr(x.list, f)) + return [] + else: + for f in fields: + if x.HasField(f): + return getattr(x, f) + assert False, "Unsupported field type in " + str(x) + + def run(self, feed_dict=None, session=None): + """Runs this operation in a `Session`. + + Calling this method will execute all preceding operations that + produce the inputs needed for this operation. + + *N.B.* Before invoking `Operation.run()`, its graph must have been + launched in a session, and either a default session must be + available, or `session` must be specified explicitly. + + Args: + feed_dict: A dictionary that maps `Tensor` objects to feed values. 
+        See [`Session.run()`](client.md#Session.run) for a description of the
+        valid feed values.
+      session: (Optional.) The `Session` to be used to run this operation. If
+        none, the default session will be used.
+    """
+    _run_using_default_session(self, feed_dict, self.graph, session)
+
+
+_gradient_registry = registry.Registry("gradient")
+
+
+class RegisterGradient(object):
+  """A decorator for registering the gradient function for an op type.
+
+  This decorator is only used when defining a new op type. For an op
+  with `m` inputs and `n` outputs, the gradient function is a function
+  that takes the original `Operation` and `n` `Tensor` objects
+  (representing the gradients with respect to each output of the op),
+  and returns `m` `Tensor` objects (representing the partial gradients
+  with respect to each input of the op).
+
+  For example, assuming that operations of type `"Sub"` take two
+  inputs `x` and `y`, and return a single output `x - y`, the
+  following gradient function would be registered:
+
+  ```python
+  @tf.RegisterGradient("Sub")
+  def _sub_grad(unused_op, grad):
+    return grad, tf.Neg(grad)
+  ```
+
+  The decorator argument `op_type` is the string type of an
+  operation. This corresponds to the `OpDef.name` field for the proto
+  that defines the operation.
+
+  @@__init__
+  """
+
+  def __init__(self, op_type):
+    """Creates a new decorator with `op_type` as the Operation type.
+
+    Args:
+      op_type: The string type of an operation. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+    """
+    if not isinstance(op_type, basestring):
+      raise TypeError("op_type must be a string")
+    self._op_type = op_type
+
+  def __call__(self, f):
+    """Registers the function `f` as gradient function for `op_type`."""
+    _gradient_registry.register(f, self._op_type)
+    return f
+
+
+def NoGradient(op_type):
+  """Specifies that ops of type `op_type` do not have a defined gradient.
+
+  This function is only used when defining a new op type.
It may be + used for ops such as `tf.size()` that are not differentiable. For + example: + + ```python + tf.NoGradient("Size") + ``` + + Args: + op_type: The string type of an operation. This corresponds to the + `OpDef.name` field for the proto that defines the operation. + + Raises: + TypeError: If `op_type` is not a string. + + """ + if not isinstance(op_type, basestring): + raise TypeError("op_type must be a string") + _gradient_registry.register(None, op_type) + + +def get_gradient_function(op): + """Returns the function that computes gradients for "op".""" + if not op.inputs: return None + try: + op_type = op.get_attr("_gradient_op_type") + except ValueError: + op_type = op.type + return _gradient_registry.lookup(op_type) + + +_shape_registry = registry.Registry("shape functions") +_default_shape_function_registry = registry.Registry("default shape functions") + +class RegisterShape(object): + """A decorator for registering the shape function for an op type. + + This decorator is only used when defining a new op type. A shape + function is a function from an `Operation` object to a list of + `TensorShape` objects, with one `TensorShape` for each output of the + operation. + + For example, assuming that operations of type `"Sub"` take two + inputs `x` and `y`, and return a single output `x - y`, all with the + same shape, the following shape function would be registered: + + ```python + @tf.RegisterShape("Sub") + def _sub_shape(op): + return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())] + ``` + + The decorator argument `op_type` is the string type of an + operation. This corresponds to the `OpDef.name` field for the proto + that defines the operation. 
+
+  """
+
+  def __init__(self, op_type):
+    """Saves the "op_type" as the Operation type."""
+    if not isinstance(op_type, basestring):
+      raise TypeError("op_type must be a string")
+    self._op_type = op_type
+
+  def __call__(self, f):
+    """Registers "f" as the shape function for "op_type"."""
+    if f is None:
+      # None is a special "weak" value that provides a default shape function,
+      # and can be overridden by a non-None registration.
+      try:
+        _default_shape_function_registry.register(_no_shape_function,
+                                                  self._op_type)
+      except KeyError:
+        # Ignore duplicate registrations of the weak value. This can
+        # occur if the op library input to wrapper generation
+        # inadvertently links in one or more of the standard op
+        # libraries.
+        pass
+    else:
+      _shape_registry.register(f, self._op_type)
+    return f
+
+
+def _no_shape_function(op):
+  return [tensor_shape.unknown_shape() for _ in op.outputs]
+
+
+def set_shapes_for_outputs(op):
+  """Uses the registered shape functions to set the shapes for op's outputs."""
+  try:
+    shape_func = _shape_registry.lookup(op.type)
+  except LookupError:
+    try:
+      shape_func = _default_shape_function_registry.lookup(op.type)
+    except LookupError:
+      raise RuntimeError("No shape function registered for standard op: %s"
+                         % op.type)
+  shapes = shape_func(op)
+  if len(op.outputs) != len(shapes):
+    raise RuntimeError(
+        "Shape function for op %s expected %d shapes but returned %d" %
+        (op, len(op.outputs), len(shapes)))
+  for output, s in zip(op.outputs, shapes):
+    output.set_shape(s)
+
+
+class Graph(object):
+  """A TensorFlow computation, represented as a dataflow graph.
+
+  A `Graph` contains a set of [`Operation`](framework.md#Operation) objects,
+  which represent units of computation; and [`Tensor`](framework.md#Tensor)
+  objects, which represent the units of data that flow between operations.
+
+  A default `Graph` is always registered, and accessible by calling
+  [`tf.get_default_graph()`](framework.md#get_default_graph).
To add an + operation to the default graph, simply call one of the functions that defines + a new `Operation`: + + ``` + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + [`Graph.as_default()`](framework.md#Graph.as_default) + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + @@__init__ + @@as_default + @@as_graph_def + @@finalize + @@finalized + + @@control_dependencies + @@device + @@name_scope + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + [`tf.GraphKeys.VARIABLES`](framework.md#GraphKeys)) for all variables that are + created during the construction of a graph. The caller may define + additional collections by specifying a new name. + + @@add_to_collection + @@get_collection + + @@as_graph_element + @@get_operation_by_name + @@get_tensor_by_name + @@get_operations + + @@get_default_device + @@seed + @@unique_name + @@version + + @@create_op + @@gradient_override_map + """ + + def __init__(self): + """Creates a new, empty Graph.""" + self._nodes_by_id = dict() + self._next_node_id = [dict()] + self._next_id_counter = 0 + self._nodes_by_name = dict() + # Current name stack: a pair of uniquified names and plain names. + self._name_stack = ("", "") + # Maps a name used in the graph to the next id to use for that name. 
+    self._names_in_use = {}
+    # Default device applied to new ops.
+    self._default_device = None
+    # Functions that will be applied to choose a device if none is specified.
+    self._device_function_stack = []
+    # Default original_op applied to new ops.
+    self._default_original_op = None
+    # Current control flow context. It could be either CondContext or
+    # WhileContext defined in ops/control_flow_ops.py
+    self._control_flow_context = None
+    # A new node will depend on the union of all of the nodes in the stack.
+    self._control_dependencies_stack = []
+    # Arbitrary collections of objects.
+    self._collections = {}
+    # The graph-level random seed
+    self._seed = None
+    # A map from op type to the kernel label that should be used.
+    self._op_to_kernel_label_map = {}
+    # A map from op type to an alternative op type that should be used when
+    # computing gradients.
+    self._gradient_override_map = {}
+    # True if the graph is considered "finalized". In that case no
+    # new operations can be added.
+    self._finalized = False
+
+  def _check_not_finalized(self):
+    """Check if the graph is finalized.
+
+    Raises:
+      RuntimeError: If the graph is finalized.
+    """
+    if self._finalized:
+      raise RuntimeError("Graph is finalized and cannot be modified.")
+
+  def _add_op(self, op):
+    """Adds 'op' to the graph.
+
+    Args:
+      op: the Operation or Tensor to add.
+
+    Raises:
+      TypeError: if op is not an Operation or Tensor.
+      ValueError: if the op.name or op._id are already used.
+ """ + self._check_not_finalized() + if not isinstance(op, (Tensor, Operation)): + raise TypeError("op must be a Tensor or Operation: %s" % op) + + if op._id in self._nodes_by_id: + raise ValueError("cannot add an op with id %d as it already " + "exists in the graph" % op._id) + if op.name in self._nodes_by_name: + raise ValueError("cannot add op with name %s as that name " + "is already used" % op.name) + self._nodes_by_id[op._id] = op + self._nodes_by_name[op.name] = op + + @property + def version(self): + """Returns a version number that increases as ops are added to the graph.""" + return self._next_id_counter + + @property + def seed(self): + return self._seed + + @seed.setter + def seed(self, seed): + self._seed = seed + + @property + def finalized(self): + """True if this graph has been finalized.""" + return self._finalized + + def finalize(self): + """Finalizes this graph, making it read-only. + + After calling `g.finalize()`, no new operations can be added to + `g`. This method is used to ensure that no operations are added + to a graph when it is shared between multiple threads, for example + when using a [`QueueRunner`](train.md#QueueRunner). + """ + self._finalized = True + + def _get_control_flow_context(self): + """Returns the current control flow context. + + Returns: + A context object. + """ + return self._control_flow_context + + def _set_control_flow_context(self, context): + """Sets the current control flow context. + + Args: + context: a context object. + """ + self._control_flow_context = context + + def as_graph_def(self, from_version=None): + """Returns a serialized `GraphDef` representation of this graph. + + This method is thread-safe. + + Args: + from_version: Optional. If this is set, returns a `GraphDef` + containing only the nodes that were added to this graph since + its `version` property had the given value. 
+ + Returns: + A + [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto) + protocol buffer. + """ + graph = graph_pb2.GraphDef() + bytesize = 0 + for op_id in sorted(self._nodes_by_id): + op = self._nodes_by_id[op_id] + if from_version is None or op_id > from_version: + graph.node.extend([op.node_def]) + bytesize += op.node_def.ByteSize() + if bytesize >= (1 << 31) or bytesize < 0: + raise ValueError("GraphDef cannot be larger than 2GB.") + return graph + + # Helper functions to create operations. + def create_op(self, op_type, inputs, dtypes, + input_types=None, name=None, attrs=None, op_def=None, + compute_shapes=True): + """Creates an `Operation` in this graph. + + This is a low-level interface for creating an `Operation`. Most + programs will not call this method directly, and instead use the + Python op constructors, such as `tf.constant()`, which add ops to + the default graph. + + Args: + op_type: The `Operation` type to create. This corresponds to the + `OpDef.name` field for the proto that defines the operation. + inputs: A list of `Tensor` objects that will be inputs to the `Operation`. + dtypes: A list of `DType` objects that will be the types of the tensors + that the operation produces. + input_types: (Optional.) A list of `DType`s that will be the types of + the tensors that the operation consumes. By default, uses the base + `DType` of each input in `inputs`. Operations that expect + reference-typed inputs must specify `input_types` explicitly. + name: (Optional.) A string name for the operation. If not specified, a + name is generated based on `op_type`. + attrs: (Optional.) A list of `AttrValue` protos for the `attr` field of + the `NodeDef` proto that will represent the operation. + op_def: (Optional.) The `OpDef` proto that describes the `op_type` that + the operation will have. + compute_shapes: (Optional.) If True, shape inference will be performed + to compute the shapes of the outputs. 
+ + Raises: + TypeError: if any of the inputs is not a `Tensor`. + + Returns: + An `Operation` object. + + """ + self._check_not_finalized() + for idx, a in enumerate(inputs): + if not isinstance(a, Tensor): + raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) + if name is None: + name = op_type + # If a names ends with a '/' it is a "name scope" and we use it as-is, + # after removing the trailing '/'. + if name and name[-1] == "/": + name = name[:-1] + else: + name = self.unique_name(name) + + node_def = _NodeDef( + op_type, name, device=self._default_device or None, attrs=attrs) + + # Apply a kernel label if one has been specified for this op_type. + try: + kernel_label = self._op_to_kernel_label_map[op_type] + node_def.attr["_kernel"].CopyFrom( + attr_value_pb2.AttrValue(s=kernel_label)) + except KeyError: + pass + + # Apply the overriding op_type for gradients if one has been + # specified for this op_type. + try: + mapped_op_type = self._gradient_override_map[op_type] + node_def.attr["_gradient_op_type"].CopyFrom( + attr_value_pb2.AttrValue(s=mapped_op_type)) + except KeyError: + pass + + control_inputs = self._control_dependencies_for_inputs(inputs) + ret = Operation(node_def, self, inputs=inputs, output_types=dtypes, + control_inputs=control_inputs, input_types=input_types, + original_op=self._default_original_op, op_def=op_def) + if compute_shapes: + set_shapes_for_outputs(ret) + self._add_op(ret) + self._record_op_seen_by_control_dependencies(ret) + # Apply any device functions in reverse order, so that the most recently + # pushed function has the first chance to apply a device to the op. + # We apply here because the result can depend on the Operation's + # signature, which is computed in the Operation constructor. 
+ for device_function in reversed(self._device_function_stack): + ret._set_device(device_function(ret)) + return ret + + def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): + """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. + + This function validates that `obj` represents an element of this + graph, and gives an informative error message if it is not. + + This function is the canonical way to get/validate an object of + one of the allowed types from an external argument reference in the + Session API. + + This method may be called concurrently from multiple threads. + + Args: + obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. + Can also be any object with an `_as_graph_element()` method that returns + a value of one of these types. + allow_tensor: If true, `obj` may refer to a `Tensor`. + allow_operation: If true, `obj` may refer to an `Operation`. + + Returns: + The `Tensor` or `Operation` in the Graph corresponding to `obj`. + + Raises: + TypeError: If `obj` is not a type we support attempting to convert + to types. + ValueError: If `obj` is of an appropriate type but invalid. For + example, an invalid string. + KeyError: If `obj` is not an object in the graph. + """ + + # The vast majority of this function is figuring + # out what an API user might be doing wrong, so + # that we can give helpful error messages. + # + # Ideally, it would be nice to split it up, but we + # need context to generate nice error messages. + + if allow_tensor and allow_operation: + types_str = "Tensor or Operation" + elif allow_tensor: + types_str = "Tensor" + elif allow_operation: + types_str = "Operation" + else: + raise ValueError("allow_tensor and allow_operation can't both be False.") + + conv_fn = getattr(obj, "_as_graph_element", None) + if conv_fn and callable(conv_fn): + obj = conv_fn() + + # If obj appears to be a name... 
+ if isinstance(obj, basestring): + name = obj + + if ":" in name and allow_tensor: + # Looks like a Tensor name and can be a Tensor. + try: + op_name, out_n = name.split(":") + out_n = int(out_n) + except: + raise ValueError("The name %s looks a like a Tensor name, but is " + "not a valid one. Tensor names must be of the " + "form \"<op_name>:<output_index>\"." % repr(name)) + if op_name in self._nodes_by_name: + op = self._nodes_by_name[op_name] + else: + raise KeyError("The name %s refers to a Tensor which does not " + "exist. The operation, %s, does not exist in the " + "graph." % (repr(name), repr(op_name))) + try: + return op.outputs[out_n] + except: + raise KeyError("The name %s refers to a Tensor which does not " + "exist. The operation, %s, exists but only has " + "%s outputs." + % (repr(name), repr(op_name), len(op.outputs))) + + elif ":" in name and not allow_tensor: + # Looks like a Tensor name but can't be a Tensor. + raise ValueError("Name %s appears to refer to a Tensor, not a %s." + % (repr(name), types_str)) + + elif ":" not in name and allow_operation: + # Looks like an Operation name and can be an Operation. + if name not in self._nodes_by_name: + raise KeyError("The name %s refers to an Operation not in the " + "graph." % repr(name)) + return self._nodes_by_name[name] + + elif ":" not in name and not allow_operation: + # Looks like an Operation name but can't be an Operation. + if name in self._nodes_by_name: + # Yep, it's an Operation name + err_msg = ("The name %s refers to an Operation, not a %s." + % (repr(name), types_str)) + else: + err_msg = ("The name %s looks like an (invalid) Operation name, " + "not a %s." % (repr(name), types_str)) + err_msg += (" Tensor names must be of the form " + "\"<op_name>:<output_index>\".") + raise ValueError(err_msg) + + elif isinstance(obj, Tensor) and allow_tensor: + # Actually obj is just the object it's referring to. 
+ return obj + elif isinstance(obj, Operation) and allow_operation: + # Actually obj is just the object it's referring to. + return obj + else: + # We give up! + raise TypeError("Can not convert a %s into a %s." + % (type(obj).__name__, types_str)) + + def get_operations(self): + """Return the list of operations in the graph. + + You can modify the operations in place, but modifications + to the list such as inserts/delete have no effect on the + list of operations known to the graph. + + This method may be called concurrently from multiple threads. + + Returns: + A list of Operations. + """ + return self._nodes_by_id.values() + + def get_operation_by_name(self, name): + """Returns the `Operation` with the given `name`. + + This method may be called concurrently from multiple threads. + + Args: + name: The name of the `Operation` to return. + + Returns: + The `Operation` with the given `name`. + + Raises: + TypeError: If `name` is not a string. + KeyError: If `name` does not correspond to an operation in this graph. + """ + + if not isinstance(name, basestring): + raise TypeError("Operation names are strings (or similar), not %s." + % type(name).__name__) + return self.as_graph_element(name, allow_tensor=False, allow_operation=True) + + def get_tensor_by_name(self, name): + """Returns the `Tensor` with the given `name`. + + This method may be called concurrently from multiple threads. + + Args: + name: The name of the `Tensor` to return. + + Returns: + The `Tensor` with the given `name`. + + Raises: + TypeError: If `name` is not a string. + KeyError: If `name` does not correspond to a tensor in this graph. + """ + # Names should be strings. + if not isinstance(name, basestring): + raise TypeError("Tensor names are strings (or similar), not %s." + % type(name).__name__) + return self.as_graph_element(name, allow_tensor=True, allow_operation=False) + + def _next_id(self): + """Id for next Operation instance. 
Also increments the internal id.""" + self._check_not_finalized() + self._next_id_counter += 1 + return self._next_id_counter + + @property + def _last_id(self): + return self._next_id_counter + + def as_default(self): + """Returns a context manager that makes this `Graph` the default graph. + + This method should be used if you want to create multiple graphs + in the same process. For convenience, a global default graph is + provided, and all ops will be added to this graph if you do not + create a new graph explicitly. Use this method the `with` keyword + to specify that ops created within the scope of a block should be + added to this graph. + + The default graph is a property of the current thread. If you + create a new thread, and wish to use the default graph in that + thread, you must explicitly add a `with g.as_default():` in that + thread's function. + + The following code examples are equivalent: + + ```python + # 1. Using Graph.as_default(): + g = tf.Graph() + with g.as_default(): + c = tf.constant(5.0) + assert c.graph is g + + # 2. Constructing and making default: + with tf.Graph().as_default() as g: + c = tf.constant(5.0) + assert c.graph is g + ``` + + Returns: + A context manager for using this graph as the default graph. + """ + return _default_graph_stack.get_controller(self) + + def add_to_collection(self, name, value): + """Stores `value` in the collection with the given `name`. + + Args: + name: The key for the collection. For example, the `GraphKeys` class + contains many standard names for collections. + value: The value to add to the collection. + """ + self._check_not_finalized() + if name not in self._collections: + self._collections[name] = [value] + else: + self._collections[name].append(value) + + def get_collection(self, name, scope=None): + """Returns a list of values in the collection with the given `name`. + + Args: + key: The key for the collection. For example, the `GraphKeys` class + contains many standard names for collections. 
+ scope: (Optional.) If supplied, the resulting list is filtered to include + only items whose name begins with this string. + + Returns: + The list of values in the collection with the given `name`, or + an empty list if no value has been added to that collection. The + list contains the values in the order under which they were + collected. + """ + if scope is None: + return self._collections.get(name, list()) + else: + c = [] + for item in self._collections.get(name, list()): + if hasattr(item, 'name') and item.name.startswith(scope): + c.append(item) + return c + + @contextlib.contextmanager + def _original_op(self, op): + """Python 'with' handler to help annotate ops with their originator. + + An op may have an 'original_op' property that indicates the op on which + it was based. For example a replica op is based on the op that was + replicated and a gradient op is based on the op that was differentiated. + + All ops created in the scope of this 'with' handler will have + the given 'op' as their original op. + + Args: + op: The Operation that all ops created in this scope will have as their + original op. + + Yields: + Nothing. + """ + old_original_op = self._default_original_op + try: + self._default_original_op = op + yield + finally: + self._default_original_op = old_original_op + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager + def name_scope(self, name): + """Returns a context manager that creates hierarchical names for operations. + + A graph maintains a stack of name scopes. A `with name_scope(...):` + statement pushes a new name onto the stack for the lifetime of the context. + + The `name` argument will be interpreted as follows: + + * A string (not ending with '/') will create a new name scope, in which + `name` is appended to the prefix of all operations created in the + context. If `name` has been used before, it will be made unique by + calling `self.unique_name(name)`. 
+ * A scope previously captured from a `with g.name_scope(...) as + scope:` statement will be treated as an "absolute" name scope, which + makes it possible to re-enter existing scopes. + * A value of `None` or the empty string will reset the current name scope + to the top-level (empty) name scope. + + For example: + + ```python + with tf.Graph().as_default() as g: + c = tf.constant(5.0, name="c") + assert c_1.name == "c" + c_1 = tf.constant(6.0, name="c") + assert c_1.name == "c_1" + + # Creates a scope called "nested" + with g.name_scope("nested") as scope: + nested_c = tf.constant(10.0, name="c") + assert nested_c.name == "nested/c" + + # Creates a nested scope called "inner". + with g.name_scope("inner"): + nested_inner_c = tf.constant(20.0, name="c") + assert nested_inner_c.name == "nested/inner/c" + + # Create a nested scope called "inner_1". + with g.name_scope("inner"): + nested_inner_1_c = tf.constant(30.0, name="c") + assert nested_inner_1_c.name == "nested/inner_1/c" + + # Treats `scope` as an absolute name scope, and + # switches to the "nested/" scope. + with g.name_scope(scope): + nested_d = tf.constant(40.0, name="d") + assert nested_d.name == "nested/d" + + with g.name_scope(""): + e = tf.constant(50.0, name="e") + assert e.name == "e" + ``` + + The name of the scope itself can be captured by `with + g.name_scope(...) as scope:`, which stores the name of the scope + in the variable `scope`. This value can be used to name an + operation that represents the overall result of executing the ops + in a scope. For example: + + ```python + inputs = tf.constant(...) + with g.name_scope('my_layer') as scope: + weights = tf.Variable(..., name="weights") + biases = tf.Variable(..., name="biases") + affine = tf.matmul(inputs, weights) + biases + output = tf.nn.relu(affine, name=scope) + ``` + + + Args: + name: A name for the scope. + + Returns: + A context manager that installs `name` as a new name scope. 
+ """ + try: + old_stack = self._name_stack + if not name: # Both for name=None nad name="" we re-set to empty scope. + new_stack = (None, None) + elif name and name[-1] == "/": + new_stack = (name[:-1], name[:-1]) + else: + new_stack = (self.unique_name(name), self._plain_name(name)) + self._name_stack = new_stack + yield "" if new_stack[0] is None else new_stack[0] + "/" + finally: + self._name_stack = old_stack + # pylint: enable=g-doc-return-or-yield + + def unique_name(self, name): + """Return a unique Operation name for "name". + + Note: You rarely need to call unique_name() directly. Most of the time you + just need to create "with g.name_scope()" blocks to generate structured + names. + + `unique_name` is used to generate structured names, separated by "/", + to help identify Operations when debugging a Graph. Operation names + are displayed in error messages reported by the TensorFlow runtime, + and in various visualization tools such as TensorBoard. + + Args: + name: The name for an `Operation`. + + Returns: + A string to be passed to `create_op()` that will be used + to name the operation being created. + """ + if self._name_stack[0]: + name = self._name_stack[0] + "/" + name + i = self._names_in_use.get(name, 0) + # Increment the number for "name". + self._names_in_use[name] = i + 1 + if i > 0: + base_name = name + # Make sure the composed name is not already used. + while name in self._names_in_use: + name = "%s_%d" % (base_name, i) + i += 1 + # Mark the composed name as used in case someone wants + # to call unique_name("name_1"). + self._names_in_use[name] = 1 + return name + + # TODO(mdevin): remove + def _plain_name(self, name): + """Return the fully scoped 'name'. + + Args: + name: a string. + + Returns: + 'name' scoped in the current name stack, without any uniquified + elements. 
+ """ + if self._name_stack[1]: + return self._name_stack[1] + "/" + name + else: + return name + + def _set_default_device(self, dev): + """Set the default device properties. + + Args: + dev: string or Device. + """ + self._default_device = _device_string(dev) + + def get_default_device(self): + """Returns the default device. + + Returns: + A string. + """ + return self._default_device + + def _push_default_device_function(self, device_function): + """Pushes the given function onto the stack of device functions. + + See Graph.device for more details. + + Args: + device_function: The function to be pushed onto the stack of device + functions. + """ + self._device_function_stack.append(device_function) + + def _pop_default_device_function(self, device_function): + """Pops the given function from the stack of device functions. + + See Graph.device for more details. + + Args: + device_function: The function to be popped from the stack of device + functions. + + Raises: + ValueError: if the device_function to be popped is not top of the stack, + or if the stack is empty. + """ + if not self._device_function_stack: + raise ValueError("Tried to pop, but the device function stack is empty") + if self._device_function_stack[-1] is not device_function: + raise ValueError("Tried to pop device function, but it was not on top " + "of the stack") + + self._device_function_stack.pop() + + @contextlib.contextmanager + def device(self, device_name_or_function): + """Returns a context manager that specifies the default device to use. + + The `device_name_or_function` argument may either be a device name + string, a device function, or None: + + * If it is a device name string, all operations constructed in + this context will be assigned to the device with that name. + * If it is a function, it will be treated as function from + Operation objects to device name strings, and invoked each time + a new Operation is created. 
The Operation will be assigned to + the device with the returned name. + * If it is None, the default device will be cleared. + + For example: + + ```python + with g.device('/gpu:0'): + # All operations constructed in this context will be placed + # on GPU 0. + with g.device(None): + # All operations constructed in this context will have no + # assigned device. + + # Defines a function from `Operation` to device string. + def matmul_on_gpu(n): + if n.type == "MatMul": + return "/gpu:0" + else: + return "/cpu:0" + + with g.device(matmul_on_gpu): + # All operations of type "MatMul" constructed in this context + # will be placed on GPU 0; all other operations will be placed + # on CPU 0. + ``` + + Args: + device_name_or_function: The device name or function to use in + the context. + + Returns: + A context manager that specifies the default device to use for newly + created ops. + """ + if callable(device_name_or_function): + try: + self._push_default_device_function(device_name_or_function) + yield + finally: + self._pop_default_device_function(device_name_or_function) + else: + try: + old_dev = self.get_default_device() + self._set_default_device(_device_string(device_name_or_function)) + yield + finally: + self._set_default_device(old_dev) + + class _ControlDependenciesController(object): + """Context manager for `control_dependencies()`.""" + + def __init__(self, graph, control_inputs): + self._graph = graph + self._control_inputs = control_inputs + self._seen_nodes = set() + +# pylint: disable=protected-access + def __enter__(self): + self._graph._push_control_dependencies_controller(self) + + def __exit__(self, unused_type, unused_value, unused_traceback): + self._graph._pop_control_dependencies_controller(self) +# pylint: enable=protected-access + + @property + def control_inputs(self): + return self._control_inputs + + def add_op(self, op): + self._seen_nodes.add(op) + + def op_in_group(self, op): + return op in self._seen_nodes + + def 
_push_control_dependencies_controller(self, controller): + self._control_dependencies_stack.append(controller) + + def _pop_control_dependencies_controller(self, controller): + assert self._control_dependencies_stack[-1] is controller + self._control_dependencies_stack.pop() + + def _current_control_dependencies(self): + ret = set() + for controller in self._control_dependencies_stack: + for op in controller.control_inputs: + ret.add(op) + return ret + + def _control_dependencies_for_inputs(self, input_tensors): + """For an op that takes `input_tensors` as inputs, compute control inputs. + + The returned control dependencies should yield an execution that + is equivalent to adding all control inputs in + self._control_dependencies_stack to a newly created op. However, + this function attempts to prune the returned control dependencies + by observing that nodes created within the same `with + control_dependencies(...):` block may have data dependencies that make + the explicit approach redundant. + + Args: + input_tensors: The direct data dependencies for an op to be created. + + Returns: + A list of control inputs for the op to be created. + """ + ret = [] + input_ops = set([t.op for t in input_tensors]) + for controller in self._control_dependencies_stack: + # If any of the input_ops already depends on the inputs from controller, + # we say that the new op is dominated (by that input), and we therefore + # do not need to add control dependences for this controller's inputs. + dominated = False + for op in input_ops: + if controller.op_in_group(op): + dominated = True + break + if not dominated: + # Don't add a control input if we already have a data dependency on i. + # NOTE(mrry): We do not currently track transitive data dependencies, + # so we may add redundant control inputs. 
+ ret.extend([c for c in controller.control_inputs if c not in input_ops]) + return ret + + def _record_op_seen_by_control_dependencies(self, op): + """Record that the given op depends on all registered control dependencies. + + Args: + op: An Operation. + """ + for controller in self._control_dependencies_stack: + controller.add_op(op) + + def control_dependencies(self, control_inputs): + """Returns a context manager that specifies control dependencies. + + Use with the `with` keyword to specify that all operations constructed + within the context should have control dependencies on + `control_inputs`. For example: + + ```python + with g.control_dependencies([a, b, c]): + # `d` and `e` will only run after `a`, `b`, and `c` have executed. + d = ... + e = ... + ``` + + Multiple calls to `control_dependencies()` can be nested, and in + that case a new `Operation` will have control dependencies on the union + of `control_inputs` from all active contexts. + + ```python + with g.control_dependencies([a, b]): + # Ops declared here run after `a` and `b`. + with g.control_dependencies([c, d]): + # Ops declared here run after `a`, `b`, `c`, and `d`. + ``` + + *N.B.* The control dependencies context applies *only* to ops that + are constructed within the context. Merely using an op or tensor + in the context does not add a control dependency. The following + example illustrates this point: + + ```python + # WRONG + def my_func(pred, tensor): + t = tf.matmul(tensor, tensor) + with tf.control_dependencies([pred]): + # The matmul op is created outside the context, so no control + # dependency will be added. + return t + + # RIGHT + def my_func(pred, tensor): + with tf.control_dependencies([pred]): + # The matmul op is created in the context, so a control dependency + # will be added. 
+ return tf.matmul(tensor, tensor) + ``` + + Args: + control_inputs: A list of `Operation` or `Tensor` objects, which + must be executed or computed before running the operations + defined in the context. + + Returns: + A context manager that specifies control dependencies for all + operations constructed within the context. + + Raises: + TypeError: If `control_inputs` is not a list of `Operation` or + `Tensor` objects. + """ + # First convert the inputs to ops, and deduplicate them. + # NOTE(mrry): Other than deduplication, we do not currently track direct + # or indirect dependencies between control_inputs, which may result in + # redundant control inputs. + control_ops = [] + current = self._current_control_dependencies() + for c in control_inputs: + if isinstance(c, Tensor): + c = c.op + elif not isinstance(c, Operation): + raise TypeError("Control input must be Operation or Tensor: %s" % c) + if c not in current: + control_ops.append(c) + current.add(c) + return self._ControlDependenciesController(self, control_ops) + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager + def _kernel_label_map(self, op_to_kernel_label_map): + """EXPERIMENTAL: A context manager for setting kernel labels. + + This context manager can be used to select particular + implementations of kernels within the scope of the context. + + For example: + + with ops.Graph().as_default() as g: + f_1 = Foo() # Uses the default registered kernel for the Foo op. + with g.kernel_label_map({"Foo": "v_2"}): + f_2 = Foo() # Uses the registered kernel with label "v_2" + # for the Foo op. + with g.kernel_label_map({"Foo": "v_3"}): + f_3 = Foo() # Uses the registered kernel with label "v_3" + # for the Foo op. + with g.kernel_label_map({"Foo": ""}): + f_4 = Foo() # Uses the default registered kernel + # for the Foo op. + + Args: + op_to_kernel_label_map: A dictionary mapping op type strings to + kernel label strings. 
+ + Returns: + A context manager that sets the kernel label to be used for one or more + ops created in that context. + + Raises: + TypeError: If op_to_kernel_label_map is not a dictionary mapping + strings to strings. + """ + if not isinstance(op_to_kernel_label_map, dict): + raise TypeError("op_to_kernel_label_map must be a dictionary mapping " + "strings to strings") + # The saved_labels dictionary stores any currently-set labels that + # will be overridden by this context manager. + saved_labels = {} + # Install the given label + for op_type, label in op_to_kernel_label_map.items(): + if not (isinstance(op_type, basestring) + and isinstance(label, basestring)): + raise TypeError("op_to_kernel_label_map must be a dictionary mapping " + "strings to strings") + try: + saved_labels[op_type] = self._op_to_kernel_label_map[op_type] + except KeyError: + pass + self._op_to_kernel_label_map[op_type] = label + try: + yield # The code within the context runs here. + finally: + # Remove the labels set for this context, and restore any saved labels. + for op_type, label in op_to_kernel_label_map.items(): + try: + self._op_to_kernel_label_map[op_type] = saved_labels[op_type] + except KeyError: + del self._op_to_kernel_label_map[op_type] + # pylint: enable=g-doc-return-or-yield + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager + def gradient_override_map(self, op_type_map): + """EXPERIMENTAL: A context manager for overriding gradient functions. + + This context manager can be used to override the gradient function + that will be used for ops within the scope of the context. + + For example: + + ```python + @tf.RegisterGradient("CustomSquare") + def _custom_square_grad(op, inputs): + # ... + + with tf.Graph().as_default() as g: + c = tf.constant(5.0) + s_1 = tf.square(c) # Uses the default gradient for tf.square. 
+ with g.gradient_override_map({"Square": "CustomSquare"}): + s_2 = tf.square(s_2) # Uses _custom_square_grad to compute the + # gradient of s_2. + ``` + + Args: + op_type_map: A dictionary mapping op type strings to alternative op + type strings. + + Returns: + A context manager that sets the alternative op type to be used for one + or more ops created in that context. + + Raises: + TypeError: If `op_type_map` is not a dictionary mapping strings to + strings. + """ + if not isinstance(op_type_map, dict): + raise TypeError("op_type_map must be a dictionary mapping " + "strings to strings") + # The saved_mappings dictionary stores any currently-set mappings that + # will be overridden by this context manager. + saved_mappings = {} + # Install the given label + for op_type, mapped_op_type in op_type_map.items(): + if not (isinstance(op_type, basestring) + and isinstance(mapped_op_type, basestring)): + raise TypeError("op_type_map must be a dictionary mapping " + "strings to strings") + try: + saved_mappings[op_type] = self._gradient_override_map[op_type] + except KeyError: + pass + self._gradient_override_map[op_type] = mapped_op_type + try: + yield # The code within the context runs here. + finally: + # Remove the labels set for this context, and restore any saved labels. + for op_type, mapped_op_type in op_type_map.items(): + try: + self._gradient_override_map[op_type] = saved_mappings[op_type] + except KeyError: + del self._gradient_override_map[op_type] + # pylint: enable=g-doc-return-or-yield + + +def device(dev): + """Wrapper for `Graph.device()` using the default graph. + + See [`Graph.name_scope()`](framework.md#Graph.name_scope) for more details. + + Args: + device_name_or_function: The device name or function to use in + the context. + + Returns: + A context manager that specifies the default device to use for newly + created ops. 
+ """ + return get_default_graph().device(dev) + + +def name_scope(name): + """Wrapper for `Graph.name_scope()` using the default graph. + + See [`Graph.name_scope()`](framework.md#Graph.name_scope) for more details. + + Args: + name: A name for the scope. + + Returns: + A context manager that installs `name` as a new name scope in the + default graph. + """ + return get_default_graph().name_scope(name) + + +def control_dependencies(control_inputs): + """Wrapper for `Graph.control_dependencies()` using the default graph. + + See [`Graph.control_dependencies()`](framework.md#Graph.control_dependencies) + for more details. + + Args: + control_inputs: A list of `Operation` or `Tensor` objects, which + must be executed or computed before running the operations + defined in the context. + + Returns: + A context manager that specifies control dependencies for all + operations constructed within the context. + """ + return get_default_graph().control_dependencies(control_inputs) + + +class _DefaultStack(threading.local): + """A thread-local stack of objects for providing implicit defaults.""" + + def __init__(self): + super(_DefaultStack, self).__init__() + self.stack = [] + + def get_default(self): + return self.stack[-1] if len(self.stack) >= 1 else None + + def reset(self): + self.stack = [] + + @contextlib.contextmanager + def get_controller(self, default): + """A context manager for manipulating a default stack.""" + try: + self.stack.append(default) + yield default + finally: + assert self.stack[-1] is default + self.stack.pop() + + +_default_session_stack = _DefaultStack() + + +def default_session(session): + """Python "with" handler for defining a default session. + + This function provides a means of registering a session for handling + Tensor.eval() and Operation.run() calls. It is primarily intended for use + by session.Session, but can be used with any object that implements + the Session.run() interface. 
+ + Use with the "with" keyword to specify that Tensor.eval() and Operation.run() + invocations within the scope of a block should be executed by a particular + session. + + The default session applies to the current thread only, so it is always + possible to inspect the call stack and determine the scope of a default + session. If you create a new thread, and wish to use the default session + in that thread, you must explicitly add a "with ops.default_session(sess):" + block in that thread's function. + + Example: + The following code examples are equivalent: + + # 1. Using the Session object directly: + sess = ... + c = tf.constant(5.0) + sess.run(c) + + # 2. Using default_session(): + sess = ... + with ops.default_session(sess): + c = tf.constant(5.0) + result = c.eval() + + # 3. Overriding default_session(): + sess = ... + with ops.default_session(sess): + c = tf.constant(5.0) + with ops.default_session(...): + c.eval(session=sess) + + Args: + session: The session to be installed as the default session. + + Returns: + A context manager for the default session. + """ + return _default_session_stack.get_controller(weakref.ref(session)) + + +def get_default_session(): + """Returns the default session for the current thread. + + The returned `Session` will be the innermost session on which a + `Session` or `Session.as_default()` context has been entered. + + *N.B.* The default session is a property of the current thread. If you + create a new thread, and wish to use the default session in that + thread, you must explicitly add a `with sess.as_default():` in that + thread's function. + + Returns: + The default `Session` being used in the current thread. + """ + ref = _default_session_stack.get_default() + if ref is None: + # No default session has been registered. + return None + else: + # De-reference ref. + ret = ref() + if ret is None: + # This should never happen with the current session implementations. 
+ raise RuntimeError("Default session has been garbage collected.") + return ret + + +def _eval_using_default_session(tensors, feed_dict, graph, session=None): + """Uses the default session to evaluate one or more tensors. + + Args: + tensors: A single Tensor, or a list of Tensor objects. + feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, + numpy ndarrays, TensorProtos, or strings. + graph: The graph in which the tensors are defined. + session: (Optional) A different session to use to evaluate "tensors". + + Returns: + Either a single numpy ndarray if "tensors" is a single tensor; or a list + of numpy ndarrays that each correspond to the respective element in + "tensors". + + Raises: + ValueError: If no default session is available; the default session + does not have "graph" as its graph; or if "session" is specified, + and it does not have "graph" as its graph. + """ + if session is None: + session = get_default_session() + if session is None: + raise ValueError("Cannot evaluate tensor using eval(): No default " + "session is registered. Use 'with " + "DefaultSession(sess)' or pass an explicit session to " + "eval(session=sess)") + if session.graph is not graph: + raise ValueError("Cannot use the default session to evaluate tensor: " + "the tensor's graph is different from the session's " + "graph. Pass an explicit session to " + "eval(session=sess).") + else: + if session.graph is not graph: + raise ValueError("Cannot use the given session to evaluate tensor: " + "the tensor's graph is different from the session's " + "graph.") + return session.run(tensors, feed_dict) + + +def _run_using_default_session(operation, feed_dict, graph, session=None): + """Uses the default session to run "operation". + + Args: + operation: The Operation to be run. + feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, + numpy ndarrays, TensorProtos, or strings. + graph: The graph in which "operation" is defined. 
+ session: (Optional) A different session to use to run "operation". + + Raises: + ValueError: If no default session is available; the default session + does not have "graph" as its graph; or if "session" is specified, + and it does not have "graph" as its graph. + """ + if session is None: + session = get_default_session() + if session is None: + raise ValueError("Cannot execute operation using Run(): No default " + "session is registered. Use 'with " + "default_session(sess)' or pass an explicit session to " + "Run(session=sess)") + if session.graph is not graph: + raise ValueError("Cannot use the default session to execute operation: " + "the operation's graph is different from the " + "session's graph. Pass an explicit session to " + "Run(session=sess).") + else: + if session.graph is not graph: + raise ValueError("Cannot use the given session to execute operation: " + "the operation's graph is different from the session's " + "graph.") + session.run(operation, feed_dict) + + +class _DefaultGraphStack(_DefaultStack): + """A thread-local stack of objects for providing an implicit default graph.""" + + def __init__(self): + super(_DefaultGraphStack, self).__init__() + self._global_default_graph = None + + def get_default(self): + """Override that returns a global default if the stack is empty.""" + ret = super(_DefaultGraphStack, self).get_default() + if ret is None: + ret = self._GetGlobalDefaultGraph() + return ret + + def _GetGlobalDefaultGraph(self): + if self._global_default_graph is None: + # TODO(mrry): Perhaps log that the default graph is being used, or set + # provide some other feedback to prevent confusion when a mixture of + # the global default graph and an explicit graph are combined in the + # same process. 
+ self._global_default_graph = Graph() + return self._global_default_graph + + def reset(self): + super(_DefaultGraphStack, self).reset() + self._global_default_graph = None + +_default_graph_stack = _DefaultGraphStack() + + +def reset_default_graph(): + """Clears the default graph stack and resets the global default graph. + + *N.B.* The default graph is a property of the current thread. This + function applies only to the current thread. + """ + _default_graph_stack.reset() + + +def get_default_graph(): + """Returns the default graph for the current thread. + + The returned graph will be the innermost graph on which a + `Graph.as_default()` context has been entered, or a global default + graph if none has been explicitly created. + + *N.B.* The default graph is a property of the current thread. If you + create a new thread, and wish to use the default graph in that + thread, you must explicitly add a `with g.as_default():` in that + thread's function. + + Returns: + The default `Graph` being used in the current thread. + """ + return _default_graph_stack.get_default() + + +def _get_graph_from_inputs(op_input_list, graph=None): + """Returns the appropriate graph to use for the given inputs. + + This library method provides a consistent algorithm for choosing the graph + in which an Operation should be constructed: + + 1. If the "graph" is specified explicitly, we validate that all of the inputs + in "op_input_list" are compatible with that graph. + 2. Otherwise, we attempt to select a graph from the first Operation- + or Tensor-valued input in "op_input_list", and validate that all other + such inputs are in the same graph. + 3. If the graph was not specified and it could not be inferred from + "op_input_list", we attempt to use the default graph. + + Args: + op_input_list: A list of inputs to an operation, which may include Tensor + and Operation objects. + graph: (Optional) The explicit graph to use. 
+ + Raises: + TypeError: If op_input_list is not a list or tuple, or if graph is not a + Graph. + ValueError: If a graph is explicitly passed and not all inputs are from it, + or if the inputs are from multiple graphs, or we could not find a graph + and there was no default graph. + + Returns: + The appropriate graph to use for the given inputs. + """ + if not isinstance(op_input_list, (list, tuple)): + raise TypeError("The op_input_list must be a list or tuple") + + # 1. If the graph is specified explicitly, we validate that all of the inputs + # are compatible with that graph. + if graph is not None: + if not isinstance(graph, Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + for op_input in op_input_list: + if isinstance(op_input, Operation): + if op_input.graph is not graph: + raise ValueError("Operation %s is not from the passed-in graph" + % op_input) + elif isinstance(op_input, Tensor): + if op_input.graph is not graph: + raise ValueError("Tensor %s is not from the passed-in graph" + % op_input) + return graph + + # 2. Otherwise, we attempt to select a graph from one of the Operation- + # or Tensor-valued inputs. + original_input = None + for op_input in op_input_list: + if isinstance(op_input, (Operation, Tensor)): + if original_input is None: + original_input = op_input + else: + assert_same_graph([original_input, op_input]) + if original_input is not None: + return original_input.graph + + # 3. If all else fails, we use the default graph, which is always there. + return get_default_graph() + + +class GraphKeys(object): + """Standard names to use for graph collections. + + The standard library uses various well-known names to collect and + retrieve values associated with a graph. For example, the + `tf.Optimizer` subclasses default to optimizing the variables + collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is + specified, but it is also possible to pass an explicit list of + variables. 
+ + The following standard keys are defined: + + * `VARIABLES`: the `Variable` objects that comprise a model, and + must be saved and restored together. See + [`tf.all_variables()`](state_ops.md#all_variables) for more details. + * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will + be trained by an optimizer. See + [`tf.trainable_variables()`](state_ops.md#trainable_variables) + for more details. + * `SUMMARIES`: the summary `Tensor` objects that have been created + in the graph. See [`tf.merge_all_summaries()`](train.md#merge_all_summaries) + for more details. + * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to + produce input for a computation. See + [`tf.start_queue_runners()`](train.md#start_queue_runners) for more details. + """ + + # Key to collect variables.Variable objects that must be saved and restored + # by the model. + VARIABLES = "variables" + # Key to collect variables.Variable objects that will be trained by the + # optimizers. + TRAINABLE_VARIABLES = "trainable_variables" + # Key to collect summaries. + SUMMARIES = "summaries" + # Key to collect QueueRunners. + QUEUE_RUNNERS = "queue_runners" + # Key to collect table initializers. + TABLE_INITIALIZERS = "table_initializer" + + +def add_to_collection(name, value): + """Wrapper for `Graph.add_to_collection()` using the default graph. + + See [`Graph.add_to_collection()`](framework.md#Graph.add_to_collection) + for more details. + + Args: + name: The key for the collection. For example, the `GraphKeys` class + contains many standard names for collections. + value: The value to add to the collection. + """ + get_default_graph().add_to_collection(name, value) + + +def get_collection(key, scope=None): + """Wrapper for `Graph.get_collection()` using the default graph. + + See [`Graph.get_collection()`](framework.md#Graph.get_collection) + for more details. + + Args: + key: The key for the collection. 
For example, the `GraphKeys` class + contains many standard names for collections. + scope: (Optional.) If supplied, the resulting list is filtered to include + only items whose name begins with this string. + + Returns: + The list of values in the collection with the given `name`, or + an empty list if no value has been added to that collection. The + list contains the values in the order under which they were + collected. + """ + return get_default_graph().get_collection(key, scope) + + +# pylint: disable=g-doc-return-or-yield +@contextlib.contextmanager +def op_scope(values, name, default_name): + """Returns a context manager for use when defining a Python op. + + This context manager validates that the given `values` are from the + same graph, ensures that that graph is the default graph, and pushes a + name scope. + + For example, to define a new Python op called `my_op`: + + ```python + def my_op(a, b, c, name=None): + with tf.op_scope([a, b, c], name, "MyOp") as scope: + a = tf.convert_to_tensor(a, name="a") + b = tf.convert_to_tensor(b, name="b") + c = tf.convert_to_tensor(c, name="c") + # Define some computation that uses `a`, `b`, and `c`. + return foo_op(..., name=scope) + ``` + + Args: + values: The list of `Tensor` arguments that are passed to the op function. + name: The name argument that is passed to the op function. + default_name: The default name to use if the `name` argument is `None`. + + Returns: + A context manager for use in defining a Python op. 
+ """ + g = _get_graph_from_inputs(values) + n = default_name if name is None else name + with g.as_default(), g.name_scope(n) as scope: + yield scope +# pylint: enable=g-doc-return-or-yield diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py new file mode 100644 index 0000000000..a406c5e56e --- /dev/null +++ b/tensorflow/python/framework/ops_test.py @@ -0,0 +1,825 @@ +"""Tests for tensorflow.python.framework.ops.""" +import tensorflow.python.platform + +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_kernel_label_op +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.ops import common_shapes +from tensorflow.python.platform import googletest + + +class TensorTest(test_util.TensorFlowTestCase): + + def testShape(self): + op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), + [], [types.float32]) + t = op.outputs[0] + self.assertEquals(tensor_shape.unknown_shape(), t.get_shape()) + t.set_shape([1, 2, 3]) + self.assertEquals([1, 2, 3], t.get_shape()) + + +class NodeDefConstructorTest(test_util.TensorFlowTestCase): + + def testNoArgs(self): + nodedef = ops._NodeDef("noop", "bar") + self.assertProtoEquals("op: 'noop' name: 'bar'", nodedef) + + def testArgs(self): + nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") + self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", + nodedef) + nodedef = ops._NodeDef("foo", "bar", device=pydev.Device(job="j")) + self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) + + +# NOTE(mrry): Dummy shape registrations for ops used in the tests. 
ops.RegisterShape("a")(None)
ops.RegisterShape("b")(None)
ops.RegisterShape("c")(None)
ops.RegisterShape("add")(None)
ops.RegisterShape("an_op")(None)
ops.RegisterShape("const")(None)
ops.RegisterShape("copy")(None)
ops.RegisterShape("foo")(None)
ops.RegisterShape("identity")(None)
ops.RegisterShape("mul")(None)
ops.RegisterShape("nonrefop")(None)
ops.RegisterShape("noop")(None)
ops.RegisterShape("refop")(None)


def _apply_op(g, *args, **kwargs):
  # Convenience wrapper around Graph.create_op: unwrap the output list when
  # the op produces exactly one tensor.
  outputs = g.create_op(*args, **kwargs).outputs
  return outputs[0] if len(outputs) == 1 else outputs


class OperationTest(test_util.TensorFlowTestCase):

  def testNoInputs(self):
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(),
                       [],
                       [types.float32, types.string])
    self.assertEquals(2, len(op.values()))
    self.assertEquals(0, len(op.inputs))
    self.assertEquals("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEquals(types.float32, float_t.dtype)
    self.assertEquals(op, float_t.op)
    self.assertEquals(0, float_t._value_index)
    self.assertEquals(0, len(float_t._consumers))
    self.assertEquals("myop", float_t._as_node_def_input())

    self.assertEquals(types.string, label_str_t.dtype)
    self.assertEquals(op, label_str_t.op)
    self.assertEquals(1, label_str_t._value_index)
    self.assertEquals(0, len(label_str_t._consumers))
    self.assertEquals("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)

  def testNoOutputs(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("noop", "myop1"), g, [], [types.float32])
    float_t, = op1.values()
    op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
    self.assertEquals(0, len(op2.values()))
    self.assertEquals(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    self.assertEquals(1, len(float_t._consumers))
    self.assertEquals(op2, float_t._consumers[0])

    self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)

self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'", + op2.node_def) + + def testInputsAndOutputs(self): + g = ops.Graph() + op1 = ops.Operation( + ops._NodeDef("noop", "myop1"), g, [], [types.float32]) + self.assertEquals(1, len(op1.values())) + float1_t, = op1.values() + + op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, + [], [types.float32, types.string]) + self.assertEquals(2, len(op2.values())) + float2_t, label2_str_t = op2.values() + + # Note that we consume label2_str_t twice here. + op3 = ops.Operation(ops._NodeDef("add", "myop3"), g, + [float1_t, label2_str_t, label2_str_t], + [types.float32, types.int32]) + self.assertEquals(2, len(op3.values())) + + self.assertEquals(1, len(float1_t._consumers)) + self.assertEquals(op3, float1_t._consumers[0]) + + self.assertEquals(0, len(float2_t._consumers)) + + self.assertEquals(2, len(label2_str_t._consumers)) + self.assertEquals(op3, label2_str_t._consumers[0]) + self.assertEquals(op3, label2_str_t._consumers[1]) + + self.assertProtoEquals(""" + op:'add' name:'myop3' + input:'myop1' input:'myop2:1' input:'myop2:1' + """, op3.node_def) + + def testDeviceObject(self): + op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], []) + op._set_device("/job:goo/device:GPU:0") + self.assertProtoEquals( + "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ", + op.node_def) + op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], []) + op._set_device(pydev.Device(job="muu", device_type="CPU", device_index=0)) + self.assertProtoEquals( + "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'", + op.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [], + [types.float32_ref, types.float32]) + self.assertProtoEquals("op:'noop' name:'op1'", + op1.node_def) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. 
+ op2 = ops.Operation( + ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [], + input_types=[types.float32_ref, types.float32]) + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + op3 = ops.Operation( + ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], []) + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testInvalidNames(self): + g = ops.Graph() + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", ""), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "_invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "-invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "/invalid"), g) + + def testShapeFunctionAbsence(self): + def _test(): + pass + g = ops.Graph() + with self.assertRaises(RuntimeError): + g.create_op("shapeless_op", [], [types.float32]) + + def testNoShapeFunction(self): + g = ops.Graph() + op = ops.Operation(ops._NodeDef("op", "an_op"), g, + output_types = [types.float32]) + self.assertEquals(tensor_shape.unknown_shape(), + _apply_op(g, "an_op", [], [types.float32]).get_shape()) + +class CreateOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + op1 = g.create_op("const", [], [types.float32], None, name="myop1") + with g.device("/device:GPU"): + op2 = g.create_op("add", + [], + [types.float32, types.string], None, + name="myop2") + op3 = g.create_op( + "foo", + [op1.values()[0], op2.values()[1], op2.values()[0]], + [types.float32, types.int32], None, + name="myop3") + self.assertEquals(None, op1.device) + self.assertEquals("/device:GPU", op2.device) + self.assertEquals(None, op3.device) + self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def) + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + op2.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' 
input:'myop2' op:'foo'", + op3.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = g.create_op("noop", [], + [types.float32_ref, types.float32], name="op1") + self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + op2 = g.create_op("refop", [ref_t, nonref_t], [], + input_types=[types.float32_ref, types.float32], + name="op2") + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + op3 = g.create_op("nonrefop", [ref_t, nonref_t], [], name="op3") + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testFinalized(self): + g = ops.Graph() + g.finalize() + with self.assertRaises(RuntimeError): + g.create_op("const", [], [types.float32], None, name="myop1") + + +class ApplyOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + t1 = _apply_op(g, "const", [], [types.float32], name="myop1") + with g.device("/device:GPU"): + t2 = _apply_op(g, "add", + [], + [types.float32, types.string], + name="myop2") + t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]], + [types.float32, types.int32], name="myop3") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, list)) + self.assertTrue(isinstance(t3, list)) + self.assertTrue(isinstance(t3[0], ops.Tensor)) + self.assertEquals("myop1", t1._as_node_def_input()) + self.assertEquals("myop2", t2[0]._as_node_def_input()) + self.assertEquals("myop2:1", t2[1]._as_node_def_input()) + self.assertEquals("myop3", t3[0]._as_node_def_input()) + # Validate that we got the right ops as well + self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def) + self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", + t2[0].op.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'foo'", + t3[0].op.node_def) + + def 
testReferenceInput(self): + g = ops.Graph() + ref_t, nonref_t = _apply_op( + g, "noop", [], [types.float32_ref, types.float32], name="op1") + self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def) + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32], + input_types=[types.float32_ref, types.float32], + name="op2") + self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", + out_2.op.node_def) + out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32], + name="op3") + self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", + out_3.op.node_def) + + +class NameStackTest(test_util.TensorFlowTestCase): + + def testBasics(self): + g = ops.Graph() + self.assertEquals("foo", g.unique_name("foo")) + self.assertEquals("foo_1", g.unique_name("foo")) + self.assertEquals("foo_2", g.unique_name("foo")) + self.assertEquals("foo_1_1", g.unique_name("foo_1")) + self.assertEquals("foo_1_2", g.unique_name("foo_1")) + self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2")) + with g.name_scope("bar"): + self.assertEquals("bar/foo", g.unique_name("foo")) + self.assertEquals("bar/foo_1", g.unique_name("foo")) + with g.name_scope(None): + self.assertEquals("foo_3", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEquals("bar/baz/foo", g.unique_name("foo")) + self.assertEquals("bar/baz/foo_1", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEquals("bar/baz_1/foo", g.unique_name("foo")) + self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo")) + with g.name_scope("quux"): + self.assertEquals("quux/foo", g.unique_name("foo")) + with g.name_scope("bar"): + with g.name_scope("baz"): + self.assertEquals("bar_1/baz/foo", g.unique_name("foo")) + self.assertEquals("foo_4", g.unique_name("foo")) + self.assertEquals("bar_2", g.unique_name("bar")) + + def testOutOfOrderUniqueName(self): + g = ops.Graph() + 
self.assertEquals("foo_2", g.unique_name("foo_2")) + self.assertEquals("foo", g.unique_name("foo")) + self.assertEquals("foo_1", g.unique_name("foo")) + self.assertEquals("foo_3", g.unique_name("foo")) + + +class NameTest(test_util.TensorFlowTestCase): + + def testGenerateName(self): + g = ops.Graph() + op0 = g.create_op("const", [], [types.float32, types.float32]) + self.assertEquals("const", op0.name) + self.assertEquals("const:0", op0.outputs[0].name) + self.assertEquals("const:1", op0.outputs[1].name) + + op1 = g.create_op("const", [], [types.float32]) + self.assertEquals("const_1", op1.name) + self.assertEquals("const_1:0", op1.outputs[0].name) + + op2 = g.create_op("const", [], [types.float32], name="my_op") + self.assertEquals("my_op", op2.name) + self.assertEquals("my_op:0", op2.outputs[0].name) + + def testname_scope(self): + g = ops.Graph() + + with g.name_scope("foo") as foo: + self.assertEquals(foo, "foo/") + with g.name_scope("foo2") as foo2: + self.assertEquals(foo2, "foo/foo2/") + with g.name_scope(None) as empty1: + self.assertEquals(empty1, "") + with g.name_scope("foo3") as foo3: + self.assertEquals(foo3, "foo3/") + with g.name_scope("") as empty2: + self.assertEquals(empty2, "") + + self.assertEquals("const", + g.create_op("const", [], [types.float32]).name) + with g.name_scope("bar") as scope: + self.assertEquals("bar/const", + g.create_op("const", [], [types.float32]).name) + self.assertEquals("bar/const_1", + g.create_op("const", [], [types.float32]).name) + # If you use the value from "with .. as", that values is used as-is. + self.assertEquals( + "bar", + g.create_op("const", [], [types.float32], name=scope).name) + with g.name_scope("baz") as scope: + with g.name_scope("quux"): + self.assertEquals("baz/quux/const", + g.create_op("const", [], [types.float32]).name) + # If you use the value from the enclosing "with .. as", nothing is pushed. 
+ with g.name_scope(scope): + self.assertEquals("baz/const", + g.create_op("const", [], [types.float32]).name) + self.assertEquals("baz", + g.create_op("const", [], [types.float32], + name=scope).name) + self.assertEquals("trailing", + g.create_op("const", [], [types.float32], + name="trailing/").name) + with g.name_scope("bar"): + self.assertEquals("bar_1/const", + g.create_op("const", [], [types.float32]).name) + with g.name_scope("bar/"): + self.assertEquals("bar/const_2", + g.create_op("const", [], [types.float32]).name) + + +class DeviceTest(test_util.TensorFlowTestCase): + + def testNoDevice(self): + g = ops.Graph() + op = g.create_op("an_op", [], [types.float32]) + self.assertEqual(None, op.device) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" } + """, gd) + + def testDevicePartialString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" } + """, gd) + + def testDeviceFull(self): + g = ops.Graph() + with g.device(pydev.Device(job="worker", replica=2, task=0, + device_type="CPU", + device_index=3)): + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/task:0/device:CPU:3" } + """, gd) + + def testNesting(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:3/task:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2" } + """, gd) + 
+ def testNestingString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:3/task:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2" } + """, gd) + + def testNestingOverrideGpuCpu(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("an_op", [], [types.float32]) + with g.device("/job:worker/replica:2/device:GPU:2"): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/replica:2/device:GPU:2" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def testNestingWithMergeDeviceFunction(self): + g = ops.Graph() + + with g.device(pydev.merge_device("/device:GPU:0")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/job:worker")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/device:CPU:0")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device("/job:ps")): + g.create_op("an_op", [], [types.float32]) + with g.device(pydev.merge_device(None)): + g.create_op("an_op", [], [types.float32]) + + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/device:GPU:0" } + node { name: "an_op_1" op: "an_op" + device: "/job:worker/device:GPU:0" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/device:CPU:0" } + node 
{ name: "an_op_3" op: "an_op" + device: "/job:ps/device:CPU:0" } + node { name: "an_op_4" op: "an_op" + device: "/job:ps/device:CPU:0" } + """, gd) + + def testNoneClearsDefault(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("an_op", [], [types.float32]) + with g.device(None): + g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [types.float32]) + gd = g.as_graph_def() + self.assertProtoEquals(""" + node { name: "an_op" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "an_op_1" op: "an_op" } + node { name: "an_op_2" op: "an_op" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + +class ObjectWithName(object): + + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name + + +class CollectionTest(test_util.TensorFlowTestCase): + + def testadd_to_collection(self): + g = ops.Graph() + g.add_to_collection("key", 12) + g.add_to_collection("other", "foo") + g.add_to_collection("key", 34) + + # Note that only blank1 is returned. + g.add_to_collection("blah", 27) + blank1 = ObjectWithName("prefix/foo") + g.add_to_collection("blah", blank1) + blank2 = ObjectWithName("junk/foo") + g.add_to_collection("blah", blank2) + + self.assertEquals(["foo"], g.get_collection("other")) + self.assertEquals([12, 34], g.get_collection("key")) + self.assertEquals([], g.get_collection("nothing")) + self.assertEquals([27, blank1, blank2], g.get_collection("blah")) + self.assertEquals([blank1], g.get_collection("blah", "prefix")) + + def testDefaulGraph(self): + with ops.Graph().as_default(): + ops.add_to_collection("key", 90) + ops.add_to_collection("key", 100) + # Collections are ordered. 
+ self.assertEquals([90, 100], ops.get_collection("key")) + + +def an_op(g): + return _apply_op(g, "an_op", [], [types.float32]) + + +ops.NoGradient("an_op") + + +def copy_op(x): + return _apply_op(x.graph, "copy", [x], [x.dtype]) + + +@ops.RegisterGradient("copy") +def _CopyGrad(op, x_grad): + _ = op + return x_grad + + +@ops.RegisterGradient("copy_override") +def _CopyOverrideGrad(op, x_grad): + _ = op + return x_grad + + +class RegistrationTest(test_util.TensorFlowTestCase): + + def testRegisterGradients(self): + g = ops.Graph() + x = an_op(g) + y = copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEquals(_CopyGrad, fn) + + def testOverrideGradients(self): + g = ops.Graph() + x = an_op(g) + with g.gradient_override_map({"copy": "copy_override"}): + y = copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEquals(_CopyOverrideGrad, fn) + + def testNonExistentOverride(self): + g = ops.Graph() + x = an_op(g) + with g.gradient_override_map({"copy": "unknown_override"}): + y = copy_op(x) + with self.assertRaisesRegexp(LookupError, "unknown_override"): + fn = ops.get_gradient_function(y.op) + + +class ComparisonTest(test_util.TensorFlowTestCase): + + def testMembershipAllowed(self): + g = ops.Graph() + t1 = _apply_op(g, "const", [], [types.float32], name="myop1") + t2 = _apply_op(g, "const", [], [types.float32], name="myop2") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, ops.Tensor)) + self.assertTrue(t1 in [t1]) + self.assertTrue(t1 not in [t2]) + + +class ControlDependenciesTest(test_util.TensorFlowTestCase): + + def testBasic(self): + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + b = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a]): + c = _apply_op(g, "const", [], [types.float32]) + d = _apply_op(g, "identity", [b], [types.float32]) + e = _apply_op(g, "identity", [c], [types.float32]) + + self.assertEqual(c.op.control_inputs, [a.op]) + 
self.assertEqual(d.op.control_inputs, [a.op]) + # e should be dominated by c. + self.assertEqual(e.op.control_inputs, []) + + def testNested(self): + g = ops.Graph() + a_1 = _apply_op(g, "const", [], [types.float32]) + a_2 = _apply_op(g, "const", [], [types.float32]) + a_3 = _apply_op(g, "const", [], [types.float32]) + a_4 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1, a_2, a_3, a_4]): + b_1 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + b_2 = _apply_op(g, "const", [], [types.float32]) + + self.assertItemsEqual( + [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs) + self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) + + def testComplex(self): + g = ops.Graph() + + # Usage pattern: + # * Nodes a_i are constants defined at the outermost scope, and are used + # as control inputs for the ith nested scope. + # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. + # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. + # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. + # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 
+ + a_1 = _apply_op(g, "const", [], [types.float32]) + a_2 = _apply_op(g, "const", [], [types.float32]) + a_3 = _apply_op(g, "const", [], [types.float32]) + a_4 = _apply_op(g, "const", [], [types.float32]) + + with g.control_dependencies([a_1]): + b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32]) + e_1 = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a_2]): + b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32]) + e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32]) + with g.control_dependencies([a_3]): + b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32]) + e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32]) + with g.control_dependencies([a_4]): + b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) + c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) + d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32]) + e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32]) + + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) + + self.assertItemsEqual([], c_1.op.control_inputs) + self.assertItemsEqual([a_2.op], c_2.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) + + self.assertItemsEqual([], d_1.op.control_inputs) + self.assertItemsEqual([], d_2.op.control_inputs) + self.assertItemsEqual([], d_3.op.control_inputs) + self.assertItemsEqual([], d_4.op.control_inputs) + + 
self.assertItemsEqual([a_1.op], e_1.op.control_inputs) + self.assertItemsEqual([a_2.op], e_2.op.control_inputs) + self.assertItemsEqual([a_3.op], e_3.op.control_inputs) + self.assertItemsEqual([a_4.op], e_4.op.control_inputs) + + def testRepeatedDependency(self): + g = ops.Graph() + a = g.create_op("foo", [], [types.float32, types.float32]) + a_0, a_1 = a.outputs + with g.control_dependencies([a_0]): + b = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a_1]): + c = _apply_op(g, "const", [], [types.float32]) + + self.assertEqual(b.op.control_inputs, [a]) + self.assertEqual(c.op.control_inputs, [a]) + + def testNoControlDependencyWithDataDependency(self): + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + with g.control_dependencies([a]): + b = _apply_op(g, "identity", [a], [types.float32]) + + self.assertEqual(b.op.control_inputs, []) + + +class GraphTest(test_util.TensorFlowTestCase): + + def setUp(self): + ops.reset_default_graph() + + def _AssertDefault(self, expected): + self.assertIs(expected, ops.get_default_graph()) + + def testGraphContextManager(self): + g0 = ops.Graph() + with g0.as_default() as g1: + self.assertIs(g0, g1) + + def testDefaultGraph(self): + orig = ops.get_default_graph() + self._AssertDefault(orig) + g0 = ops.Graph() + self._AssertDefault(orig) + context_manager_0 = g0.as_default() + self._AssertDefault(orig) + with context_manager_0 as g0: + self._AssertDefault(g0) + with ops.Graph().as_default() as g1: + self._AssertDefault(g1) + self._AssertDefault(g0) + self._AssertDefault(orig) + + def testAsGraphElementConversions(self): + class ConvertibleObj(object): + + def _as_graph_element(self): + return "const:0" + + class NonConvertibleObj(object): + + pass + + g = ops.Graph() + a = _apply_op(g, "const", [], [types.float32]) + self.assertEqual(a, g.as_graph_element(ConvertibleObj())) + with self.assertRaises(TypeError): + g.as_graph_element(NonConvertibleObj()) + + def testAssertSameGraph(self): + 
g0 = ops.Graph() + a = g0.create_op("a", [], [types.float32]) + b = g0.create_op("b", [], [types.float32]) + ops.assert_same_graph([a, b]) + ops.assert_same_graph([a, b], g0) + g1 = ops.Graph() + c = g1.create_op("c", [], [types.float32]) + self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c]) + self.assertRaises(ValueError, ops.assert_same_graph, [c], g0) + self.assertRaises(ValueError, ops.assert_same_graph, [a], g1) + + sparse = ops.SparseTensor( + _apply_op(g0, "const", [], [types.int64]), + _apply_op(g0, "const", [], [types.float32]), + _apply_op(g0, "const", [], [types.int64])) + ops.assert_same_graph([sparse, a, b]) + ops.assert_same_graph([sparse, a, b], g0) + self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c]) + self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c], g1) + +ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) + + +class KernelLabelTest(test_util.TensorFlowTestCase): + + def testNoLabel(self): + with self.test_session(): + self.assertAllEqual("My label is: default", + test_kernel_label_op.kernel_label().eval()) + + def testLabelMap(self): + with self.test_session() as sess: + default_1 = test_kernel_label_op.kernel_label() + # pylint: disable=protected-access + with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): + overload_1_1 = test_kernel_label_op.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): + overload_2 = test_kernel_label_op.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": ""}): + default_2 = test_kernel_label_op.kernel_label() + overload_1_2 = test_kernel_label_op.kernel_label() + # pylint: enable=protected-access + default_3 = test_kernel_label_op.kernel_label() + + self.assertAllEqual("My label is: default", default_1.eval()) + self.assertAllEqual("My label is: default", default_2.eval()) + self.assertAllEqual("My label is: default", default_3.eval()) + self.assertAllEqual("My label is: overload_1", 
overload_1_1.eval()) + self.assertAllEqual("My label is: overload_1", overload_1_2.eval()) + self.assertAllEqual("My label is: overload_2", overload_2.eval()) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc new file mode 100644 index 0000000000..5c1b4462d5 --- /dev/null +++ b/tensorflow/python/framework/python_op_gen.cc @@ -0,0 +1,678 @@ +#include "tensorflow/python/framework/python_op_gen.h" + +#include <stdio.h> +#include <unordered_map> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace { + +const int kRightMargin = 78; + +bool IsPythonReserved(const string& s) { + static const std::set<string>* const kPythonReserved = new std::set<string>( + {// Keywords in Python, from: + // import keyword + // print keyword.kwlist + "and", "as", "assert", "break", "class", "continue", "def", "del", + "elif", "else", "except", "exec", "finally", "for", "from", "global", + "if", "import", "in", "is", "lambda", "not", "or", "pass", "print", + "raise", "return", "try", "while", "with", "yield", + // Built-in functions and types in Python, from: + // [x for x in dir(__builtins__) if not x[0].islower()] + "ArithmeticError", "AssertionError", "AttributeError", "BaseException", + "BufferError", "BytesWarning", "DeprecationWarning", "EOFError", + "Ellipsis", 
"EnvironmentError", "Exception", "False", + "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError", + "ImportError", "ImportWarning", "IndentationError", "IndexError", + "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError", + "NameError", "None", "NotImplemented", "NotImplementedError", "OSError", + "OverflowError", "PendingDeprecationWarning", "ReferenceError", + "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration", + "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError", + "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError", + "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError", + "UnicodeWarning", "UserWarning", "ValueError", "Warning", + "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__", + "__package__", + // Imports and symbols used in the generated code: + "_op_def_lib", "text_format", "op_def_pb2", "op_def_library", "ops"}); + + return kPythonReserved->count(s) > 0; +} + +// Add a _ to the end of s if necessary to avoid a Python keyword or built-in. +string AvoidPythonReserved(const string& s) { + if (IsPythonReserved(s)) return strings::StrCat(s, "_"); + return s; +} + +// Indent the first line by "initial" spaces and all following lines +// by "rest" spaces. +string Indent(int initial, int rest, StringPiece in) { + // TODO(josh11b): Also word-wrapping? + string copy(in.data(), in.size()); + str_util::StripTrailingWhitespace(&copy); + std::vector<string> v = str_util::Split(copy, '\n'); + + string result; + bool first = true; + for (const string& line : v) { + if (first) { + result = strings::StrCat(Spaces(initial), line, "\n"); + first = false; + } else { + if (line.empty()) { + strings::StrAppend(&result, "\n"); + } else { + strings::StrAppend(&result, Spaces(rest), line, "\n"); + } + } + } + return result; +} + +// Adds append to *dest, with a space if the first line will be <= width, +// or a newline otherwise. 
+void AppendWithinWidth(string* dest, StringPiece append, int width) { + auto first_line = append.find('\n'); + if (first_line == string::npos) first_line = append.size(); + if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) { + strings::StrAppend(dest, "\n", append); + } else { + strings::StrAppend(dest, " ", append); + } +} + +void RemoveDescriptionsFromOpDef(OpDef* op_def) { + for (int i = 0; i < op_def->input_arg_size(); ++i) { + op_def->mutable_input_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->output_arg_size(); ++i) { + op_def->mutable_output_arg(i)->clear_description(); + } + for (int i = 0; i < op_def->attr_size(); ++i) { + op_def->mutable_attr(i)->clear_description(); + } + op_def->clear_summary(); + op_def->clear_description(); +} + +// Like DataTypeString() but uses the Python names for the +// float types. +string PythonDataTypeString(DataType dtype) { + switch (dtype) { + case DT_FLOAT: + return "float32"; + case DT_DOUBLE: + return "float64"; + default: + return DataTypeString(dtype); + } +} + +string TypeString(DataType dtype, bool ref) { + if (ref) { + return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`"); + } else { + return strings::StrCat("`", PythonDataTypeString(dtype), "`"); + } +} + +string TypeListString(const AttrValue& value) { + string ret; + for (int t : value.list().type()) { + if (!ret.empty()) strings::StrAppend(&ret, ", "); + DataType dtype = static_cast<DataType>(t); + if (IsRefType(dtype)) { + strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)), + " mutable"); + } else { + strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`"); + } + } + return ret; +} + +string SingleTensorName(DataType dtype, bool is_ref) { + const string type_str = TypeString(dtype, is_ref); + return strings::StrCat("A `Tensor` of type ", type_str, "."); +} + +const char kUnknownTensorType[] = {"A `Tensor`."}; + +string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, 
+ const std::unordered_map<string, string>& inferred_attrs, + bool is_output) { + if (!arg.number_attr().empty()) { + // N Tensors with the same type + const string* original_arg = + gtl::FindOrNull(inferred_attrs, arg.number_attr()); + string prefix; + if (original_arg == nullptr) { + prefix = strings::StrCat("A list of `", arg.number_attr(), "`"); + } else if (*original_arg == arg.name()) { + const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); + if (attr->has_minimum() && attr->minimum() > 0) { + prefix = strings::StrCat("A list of at least ", attr->minimum()); + } else { + prefix = "A list of"; + } + } else { + prefix = strings::StrCat( + "A list with the same number of `Tensor` objects as `", + AvoidPythonReserved(*original_arg), "` of"); + } + + if (arg.type() != DT_INVALID) { + return strings::StrCat(prefix, " `Tensor` objects of type ", + TypeString(arg.type(), arg.is_ref()), "."); + } else { + original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr()); + if (arg.is_ref()) { + strings::StrAppend(&prefix, " mutable"); + } + if (original_arg == nullptr) { + return strings::StrCat(prefix, " `Tensor` objects of type ", + arg.type_attr(), "."); + } else if (*original_arg == arg.name()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); + if (attr->has_allowed_values()) { + return strings::StrCat(prefix, + " `Tensor` objects of the same type in: ", + TypeListString(attr->allowed_values()), "."); + } else { + return strings::StrCat(prefix, " `Tensor` objects of the same type."); + } + } else { + return strings::StrCat(prefix, " `Tensor` objects of the same type as ", + AvoidPythonReserved(*original_arg), "."); + } + } + } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) { + const bool is_list = !arg.type_list_attr().empty(); + const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr(); + const OpDef::AttrDef* attr = FindAttr(attr_name, op_def); + const string mutable_str = arg.is_ref() ? 
"mutable " : ""; + const string prefix = + is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects") + : strings::StrCat("A ", mutable_str, "`Tensor`"); + const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name); + if (original_arg == nullptr) { + return strings::StrCat(prefix, " of type `", attr_name, "`."); + } else if (*original_arg == arg.name()) { + if (attr->has_allowed_values()) { + if (is_list) { + return strings::StrCat(prefix, " with types from: ", + TypeListString(attr->allowed_values()), "."); + } else { + return strings::StrCat( + prefix, is_output ? ". Has one of the following types: " + : ". Must be one of the following types: ", + TypeListString(attr->allowed_values()), "."); + } + } else { + return strings::StrCat(prefix, "."); + } + } else { + return strings::StrCat(prefix, + is_output ? ". Has the same type as `" + : ". Must have the same type as `", + AvoidPythonReserved(*original_arg), "`."); + } + } else { + return SingleTensorName(arg.type(), arg.is_ref()); + } +} + +void PrintReturns(const OpDef& op_def, + const std::vector<string>& output_type_string) { + DCHECK_EQ(op_def.output_arg_size(), output_type_string.size()); + const int num_outs = op_def.output_arg_size(); + printf("\n Returns:\n"); + if (num_outs == 0) { + printf(" The created Operation.\n"); + } else { + if (num_outs == 1) { + StringPiece description = op_def.output_arg(0).description(); + if (ConsumeEquals(&description)) { // Skip the generated type info. + printf("%s", Indent(4, 4, description).c_str()); + } else { + // Special case of one output, don't use the name of the output unless + // there is no description. + string desc = output_type_string.empty() ? kUnknownTensorType + : output_type_string[0]; + if (desc == kUnknownTensorType) { + // Special case where we don't understand how the output tensor type + // depends on the input tensor types, just use the output arg + // description if we can. 
+ if (!description.empty()) { + desc = op_def.output_arg(0).description(); + } else if (!op_def.output_arg(0).name().empty()) { + desc = strings::StrCat(" The ", op_def.output_arg(0).name(), + " `Tensor`."); + } + } else if (!description.empty()) { + AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); + } + printf("%s", Indent(4, 4, desc).c_str()); + } + } else { + std::vector<string> out_names(num_outs); + for (int i = 0; i < num_outs; ++i) { + if (!op_def.output_arg(i).name().empty()) { + out_names[i] = op_def.output_arg(i).name(); + } else { + out_names[i] = strings::StrCat("output", i); + } + } + printf(" A tuple of `Tensor` objects (%s).\n", + str_util::Join(out_names, ", ").c_str()); + for (int i = 0; i < num_outs; ++i) { + string desc = strings::StrCat(out_names[i], ": "); + StringPiece description = op_def.output_arg(i).description(); + if (ConsumeEquals(&description)) { // Skip the generated type info. + strings::StrAppend(&desc, description); + } else { + const string type = static_cast<size_t>(i) < output_type_string.size() + ? output_type_string[i] + : kUnknownTensorType; + if (!description.empty()) { + if (type == kUnknownTensorType) { + // Special case where we don't understand how the output tensor + // type depends on the input tensor types, so we just use the + // output arg description. 
+ strings::StrAppend(&desc, description); + } else { + strings::StrAppend(&desc, type, " ", description); + } + } else { + strings::StrAppend(&desc, type); + } + } + printf("%s", Indent(4, 6, desc).c_str()); + } + } + } +} + +string StringToPython(const string& str) { + return strings::StrCat("\"", str_util::CEscape(str), "\""); +} + +string DataTypeToPython(DataType dtype) { + return strings::StrCat("tf.", PythonDataTypeString(dtype)); +} + +string ShapeToPython(const TensorShapeProto& shape) { + string python = "["; + for (const auto& dim : shape.dim()) { + if (python.size() > 1) strings::StrAppend(&python, ", "); + if (!dim.name().empty()) { + strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ", + dim.size(), ")"); + } else { + strings::StrAppend(&python, dim.size()); + } + } + strings::StrAppend(&python, "]"); + return python; +} + +string AttrListToPython(const AttrValue& value) { + string ret; + if (value.list().s_size() > 0) { + for (int i = 0; i < value.list().s_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, StringToPython(value.list().s(i))); + } + } else if (value.list().i_size() > 0) { + for (int i = 0; i < value.list().i_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, value.list().i(i)); + } + } else if (value.list().f_size() > 0) { + for (int i = 0; i < value.list().f_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, value.list().f(i)); + } + } else if (value.list().b_size() > 0) { + for (int i = 0; i < value.list().b_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, value.list().b(i) ? 
"True" : "False"); + } + } else if (value.list().type_size() > 0) { + for (int i = 0; i < value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataTypeToPython(value.list().type(i))); + } + } else if (value.list().shape_size() > 0) { + for (int i = 0; i < value.list().shape_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, ShapeToPython(value.list().shape(i))); + } + } + return ret; +} + +string AttrValueToPython(const string& type, const AttrValue& value) { + if (type == "string") { + return StringToPython(value.s()); + } else if (type == "int") { + return strings::StrCat(value.i()); + } else if (type == "float") { + return strings::StrCat(value.f()); + } else if (type == "bool") { + return value.b() ? "True" : "False"; + } else if (type == "type") { + return DataTypeToPython(value.type()); + } else if (type == "shape") { + return ShapeToPython(value.shape()); + } else { + return strings::StrCat("[", AttrListToPython(value), "]"); + } +} + +// Requires: ValidateOpDef(op_def).ok() +void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) { + // Map from attr name to the first input arg it is inferred from. + std::unordered_map<string, string> inferred_attrs; + // This has all the input args followed by those attrs that don't have + // defaults. + std::vector<string> args_no_default; + // The parameters with defaults (these have to be listed after those without). + // No input args are included, just attrs and the graph ("g") parameter. 
+ std::vector<string> args_with_defaults; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const auto& arg(op_def.input_arg(i)); + args_no_default.push_back(arg.name()); + if (!arg.type_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name()); + } else if (!arg.type_list_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(), + arg.name()); + } + if (!arg.number_attr().empty()) { + gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name()); + } + } + for (int i = 0; i < op_def.attr_size(); ++i) { + const auto& attr(op_def.attr(i)); + // Do not add inferred attrs to the Python function signature. + if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) { + if (attr.has_default_value()) { + args_with_defaults.push_back(attr.name()); + } else { + args_no_default.push_back(attr.name()); + } + } + } + + // Save the list of attr parameters (attrs that won't be inferred), + // those with defaults go at the end. + std::vector<string> attrs; + // Get the attrs in the order we want by taking the attrs without defaults + // from the end of args_no_default, and adding args_no_default (before + // "g" gets added to args_no_default, so it only has attrs). 
+ attrs.reserve(args_no_default.size() - op_def.input_arg_size() + + args_with_defaults.size()); + attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(), + args_no_default.end()); + attrs.insert(attrs.end(), args_with_defaults.begin(), + args_with_defaults.end()); + + std::vector<string> param_names; + param_names.reserve(args_no_default.size() + args_with_defaults.size()); + string parameters; + for (const string& name : args_no_default) { + if (!parameters.empty()) strings::StrAppend(&parameters, ", "); + const string param = AvoidPythonReserved(name); + strings::StrAppend(&parameters, param); + param_names.push_back(param); + } + for (const string& name : args_with_defaults) { + if (!parameters.empty()) strings::StrAppend(&parameters, ", "); + const string param = AvoidPythonReserved(name); + strings::StrAppend(&parameters, param, "=None"); + param_names.push_back(param); + } + const bool has_args = args_no_default.size() + args_with_defaults.size() > 0; + + // Print: def Function(parameters): + const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name); + + const string def_prefix = strings::StrCat("def ", lower_op_name, "("); + const string def_suffix = + strings::StrCat(parameters, has_args ? ", " : "", "name=None):"); + + printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str()); + + // Format the Op's descriptions so that it can be a Python docstring. + string comment; + if (op_def.summary().empty()) { + comment = "TODO: add doc.\n"; + } else { + comment = strings::StrCat(op_def.summary(), "\n"); + if (!op_def.description().empty()) { + strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description())); + } + } + + printf(R"( r"""%s + Args: +)", + comment.c_str()); + + // Inputs + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const auto& arg(op_def.input_arg(i)); + StringPiece description = op_def.input_arg(i).description(); + string desc; + if (ConsumeEquals(&description)) { // Skip the generated type info. 
+ desc = strings::StrCat(param_names[i], ": "); + } else { + desc = strings::StrCat(param_names[i], ": ", + ArgTypeName(op_def, arg, inferred_attrs, false)); + } + if (!description.empty()) { + AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); + } + printf("%s", Indent(4, 6, desc).c_str()); + } + + // Attrs + for (const string& name : attrs) { + const auto& attr = *FindAttr(name, op_def); + string desc = strings::StrCat(AvoidPythonReserved(name), ": "); + + static const char* const kAttrTypeName[][2] = { + {"string", "`string`"}, + {"list(string)", "list of `strings`"}, + {"int", "`int`"}, + {"list(int)", "list of `ints`"}, + {"float", "`float`"}, + {"list(float)", "list of `floats`"}, + {"bool", "`bool`"}, + {"list(bool)", "list of `bools`"}, + {"type", "`tf.DType`"}, + {"list(type)", "list of `tf.DTypes`"}, + {"shape", "`tf.TensorShape` or list of `ints`"}, + {"list(shape)", + "list of shapes (each a `tf.TensorShape` or list of `ints`)"}, + }; + for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { + if (attr.type() == kAttrTypeName[i][0]) { + string s; + if (attr.has_default_value()) { + s = strings::StrCat("optional ", kAttrTypeName[i][1]); + } else { + s = kAttrTypeName[i][1]; + } + if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) { + strings::StrAppend(&desc, "An ", s); + } else { + strings::StrAppend(&desc, "A ", s); + } + break; + } + } + + if (attr.has_allowed_values()) { + strings::StrAppend(&desc, " from: `", + AttrListToPython(attr.allowed_values()), "`"); + } + + if (attr.has_minimum()) { + if (attr.type() == "int") { + strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`"); + } else if (attr.minimum() > 0) { + strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`"); + } + } + + strings::StrAppend(&desc, "."); + + if (attr.has_default_value()) { + strings::StrAppend(&desc, " Defaults to `", + AttrValueToPython(attr.type(), attr.default_value()), + "`."); + } + + if 
(!attr.description().empty()) { + AppendWithinWidth(&desc, attr.description(), + kRightMargin - 4 /* indent */); + } + printf("%s", Indent(4, 6, desc).c_str()); + } + + printf(" name: A name for the operation (optional).\n"); + + std::vector<string> output_type_string; + output_type_string.reserve(op_def.output_arg_size()); + for (int i = 0; i < op_def.output_arg_size(); ++i) { + output_type_string.push_back( + ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true)); + } + PrintReturns(op_def, output_type_string); + + string return_prefix = strings::StrCat(" return _op_def_lib.apply_op("); + string return_args = strings::StrCat("\"", op_def.name(), "\", "); + for (size_t i = 0; i < param_names.size(); ++i) { + strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", "); + } + strings::StrAppend(&return_args, "name=name)"); + + printf(R"( """ +%s +)", + // Wrap the arguments, and indent to the (. + WordWrap(return_prefix, return_args, kRightMargin).c_str()); + + printf("\n\n"); +} + +void GenerateLowerCaseOpName(const string& str, string* result) { + char joiner = '_'; + int last_index = str.size() - 1; + for (int i = 0; i <= last_index; ++i) { + char c = str[i]; + // Emit a joiner only if a previous-lower-to-now-upper or a + // now-upper-to-next-lower transition happens. + if (isupper(c) && (i > 0)) { + if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) { + result->push_back(joiner); + } + } + result->push_back(tolower(c)); + } +} + +} // namespace + +void PrintPythonOps(const OpList& ops, const string& hidden_ops, + bool require_shapes) { + // Header + // TODO(josh11b): Mention the library for which wrappers are being generated. + printf(R"("""Python wrappers around Brain. + +This file is MACHINE GENERATED! Do not edit. 
+""" + +from google.protobuf import text_format + +from tensorflow.core.framework import op_def_pb2 +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import ops +from tensorflow.python.ops import op_def_library + + +)"); + + std::vector<string> hidden_vec = str_util::Split(hidden_ops, ','); + + // We'll make a copy of ops that filters out descriptions. + OpList cleaned_ops; + auto out = cleaned_ops.mutable_op(); + out->Reserve(ops.op_size()); + for (const auto& op_def : ops.op()) { + bool is_hidden = false; + for (const string& hidden : hidden_vec) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } + } + + // PrintPythonOp(op_def, is_hidden, op_def.name()); + string lower_case_name; + GenerateLowerCaseOpName(op_def.name(), &lower_case_name); + + // When users create custom python wrappers, they may link in the + // default op registry by accident, and because they can't + // enumerate all 'hidden' symbols, this guard is to prevent + // instantiating a python reserved word in their wrapper. 
+ if (!is_hidden && IsPythonReserved(lower_case_name)) { + continue; + } + + PrintPythonOp(op_def, is_hidden, lower_case_name); + + if (!require_shapes) { + printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str()); + } + + auto added = out->Add(); + *added = op_def; + RemoveDescriptionsFromOpDef(added); + } + + printf(R"(def _InitOpDefLibrary(): + op_list = op_def_pb2.OpList() + text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list) + op_def_registry.register_op_list(op_list) + op_def_lib = op_def_library.OpDefLibrary() + op_def_lib.add_op_list(op_list) + return op_def_lib + + +_InitOpDefLibrary.op_list_ascii = """%s""" + + +_op_def_lib = _InitOpDefLibrary() +)", + cleaned_ops.DebugString().c_str()); +} + +} // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h new file mode 100644 index 0000000000..488f7431e0 --- /dev/null +++ b/tensorflow/python/framework/python_op_gen.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ + +#include <string> +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Result is printed to stdout. hidden_ops should be a comma-separated +// list of Op names that should get a leading _ in the output. 
+void PrintPythonOps(const OpList& ops, const string& hidden_ops, + bool require_shapes); + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_H_ diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc new file mode 100644 index 0000000000..29afe35598 --- /dev/null +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -0,0 +1,30 @@ +#include "tensorflow/python/framework/python_op_gen.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { + +void PrintAllPythonOps(const char* hidden, bool require_shapes) { + OpList ops; + OpRegistry::Global()->Export(false, &ops); + PrintPythonOps(ops, hidden, require_shapes); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc == 2) { + tensorflow::PrintAllPythonOps("", std::string(argv[1]) == "1"); + } else if (argc == 3) { + tensorflow::PrintAllPythonOps(argv[1], std::string(argv[2]) == "1"); + } else { + return -1; + } + return 0; +} diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py new file mode 100644 index 0000000000..d0ffee7042 --- /dev/null +++ b/tensorflow/python/framework/random_seed.py @@ -0,0 +1,136 @@ +"""For seeding individual ops based on a graph-level seed. +""" + +from tensorflow.python.framework import ops + + +_DEFAULT_GRAPH_SEED = 87654321 + + +def get_seed(op_seed): + """Returns the local seeds an operation should use given an op-specific seed. + + Given operation-specific seed, `op_seed`, this helper function returns two + seeds derived from graph-level and op-level seeds. 
Many random operations + internally use the two seeds to allow user to change the seed globally for a + graph, or for only specific operations. + + For details on how the graph-level seed interacts with op seeds, see + [`set_random_seed`](constant_op.md#set_random_seed). + + Args: + op_seed: integer. + + Returns: + A tuple of two integers that should be used for the local seed of this + operation. + """ + graph_seed = ops.get_default_graph().seed + if graph_seed is not None: + if op_seed is not None: + return graph_seed, op_seed + else: + return graph_seed, ops.get_default_graph()._last_id + else: + if op_seed is not None: + return _DEFAULT_GRAPH_SEED, op_seed + else: + return None, None + + +def set_random_seed(seed): + """Sets the graph-level random seed. + + Operations that rely on a random seed actually derive it from two seeds: + the graph-level and operation-level seeds. This sets the graph-level seed. + + Its interactions with operation-level seeds is as follows: + + 1. If neither the graph-level nor the operation seed is set: + A random seed is used for this op. + 2. If the graph-level seed is set, but the operation seed is not: + The system deterministically picks an operation seed in conjunction + with the graph-level seed so that it gets a unique random sequence. + 3. If the graph-level seed is not set, but the operation seed is set: + A default graph-level seed and the specified operation seed are used to + determine the random sequence. + 4. If both the graph-level and the operation seed are set: + Both seeds are used in conjunction to determine the random sequence. 
+ + To illustrate the user-visible effects, consider these examples: + + To generate different sequences across sessions, set neither + graph-level nor op-level seeds: + + ```python + a = tf.random_uniform([1]) + b = tf.random_normal([1]) + + print "Session 1" + with tf.Session() as sess1: + print sess1.run(a) # generates 'A1' + print sess1.run(a) # generates 'A2' + print sess1.run(b) # generates 'B1' + print sess1.run(b) # generates 'B2' + + print "Session 2" + with tf.Session() as sess2: + print sess2.run(a) # generates 'A3' + print sess2.run(a) # generates 'A4' + print sess2.run(b) # generates 'B3' + print sess2.run(b) # generates 'B4' + ``` + + To generate the same repeatable sequence for an op across sessions, set the + seed for the op: + + ```python + a = tf.random_uniform([1], seed=1) + b = tf.random_normal([1]) + + # Repeatedly running this block with the same graph will generate the same + # sequence of values for 'a', but different sequences of values for 'b'. + print "Session 1" + with tf.Session() as sess1: + print sess1.run(a) # generates 'A1' + print sess1.run(a) # generates 'A2' + print sess1.run(b) # generates 'B1' + print sess1.run(b) # generates 'B2' + + print "Session 2" + with tf.Session() as sess2: + print sess2.run(a) # generates 'A1' + print sess2.run(a) # generates 'A2' + print sess2.run(b) # generates 'B3' + print sess2.run(b) # generates 'B4' + ``` + + To make the random sequences generated by all ops be repeatable across + sessions, set a graph-level seed: + + ```python + tf.set_random_seed(1234) + a = tf.random_uniform([1]) + b = tf.random_normal([1]) + + # Repeatedly running this block with the same graph will generate different + # sequences of 'a' and 'b'. 
+ print "Session 1" + with tf.Session() as sess1: + print sess1.run(a) # generates 'A1' + print sess1.run(a) # generates 'A2' + print sess1.run(b) # generates 'B1' + print sess1.run(b) # generates 'B2' + + print "Session 2" + with tf.Session() as sess2: + print sess2.run(a) # generates 'A1' + print sess2.run(a) # generates 'A2' + print sess2.run(b) # generates 'B1' + print sess2.run(b) # generates 'B2' + ``` + + Args: + seed: integer. + """ + ops.get_default_graph().seed = seed diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py new file mode 100644 index 0000000000..d9556f0a06 --- /dev/null +++ b/tensorflow/python/framework/registry.py @@ -0,0 +1,64 @@ +"""Registry mechanism for "registering" classes/functions for general use. + +This is typically used with a decorator that calls Register for adding +a class or function to a registry. +""" + +import traceback + +from tensorflow.python.platform import logging + + +# Registry mechanism below is based on mapreduce.python.mrpython.Register. +_LOCATION_TAG = "location" +_TYPE_TAG = "type" + + +class Registry(object): + """Provides a registry for saving objects.""" + + def __init__(self, name): + """Creates a new registry.""" + self._name = name + self._registry = dict() + + def register(self, candidate, name=None): + """Registers a Python object "candidate" for the given "name". + + Args: + candidate: the candidate object to add to the registry. + name: an optional string specifying the registry key for the candidate. + If None, candidate.__name__ will be used. + Raises: + KeyError: If same name is used twice. + """ + if not name: + name = candidate.__name__ + if name in self._registry: + (filename, line_number, function_name, _) = ( + self._registry[name][_LOCATION_TAG]) + raise KeyError("Registering two %s with name '%s' !" 
+ "(Previous registration was in %s %s:%d)" % + (self._name, name, function_name, filename, line_number)) + + logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name) + # stack trace is [this_function, Register(), user_function,...] + # so the user function is #2. + stack = traceback.extract_stack() + self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]} + + def lookup(self, name): + """Looks up "name". + + Args: + name: a string specifying the registry key for the candidate. + Returns: + Registered object if found + Raises: + LookupError: if "name" has not been registered. + """ + if name in self._registry: + return self._registry[name][_TYPE_TAG] + else: + raise LookupError( + "%s registry has no entry for: %s" % (self._name, name)) diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py new file mode 100644 index 0000000000..5b4f261ceb --- /dev/null +++ b/tensorflow/python/framework/registry_test.py @@ -0,0 +1,38 @@ +"""Tests for tensorflow.ops.registry.""" + +from tensorflow.python.framework import registry +from tensorflow.python.platform import googletest + + +class RegistryTest(googletest.TestCase): + + class Foo(object): + pass + + def testRegisterClass(self): + myreg = registry.Registry('testfoo') + with self.assertRaises(LookupError): + myreg.lookup('Foo') + myreg.register(RegistryTest.Foo, 'Foo') + assert myreg.lookup('Foo') == RegistryTest.Foo + + def testRegisterFunction(self): + myreg = registry.Registry('testbar') + with self.assertRaises(LookupError): + myreg.lookup('Bar') + myreg.register(bar, 'Bar') + assert myreg.lookup('Bar') == bar + + def testDuplicate(self): + myreg = registry.Registry('testbar') + myreg.register(bar, 'Bar') + with self.assertRaises(KeyError): + myreg.register(bar, 'Bar') + + +def bar(): + pass + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/python/framework/tensor_shape.py 
b/tensorflow/python/framework/tensor_shape.py new file mode 100644 index 0000000000..d4f27696d4 --- /dev/null +++ b/tensorflow/python/framework/tensor_shape.py @@ -0,0 +1,743 @@ +"""Helper classes for tensor shape inference.""" +import tensorflow.python.platform + + +class Dimension(object): + """Represents the value of one dimension in a TensorShape.""" + + def __init__(self, value): + """Creates a new Dimension with the given value.""" + if value is None: + self._value = None + else: + self._value = int(value) + + def __repr__(self): + return "Dimension(%s)" % repr(self._value) + + def __eq__(self, other): + """Returns true if `other` has the same known value as this Dimension.""" + other = as_dimension(other) + if self._value is None or other.value is None: + return None + return self._value == other.value + + def __ne__(self, other): + """Returns true if `other` has a different known value from `self`.""" + other = as_dimension(other) + if self._value is None or other.value is None: + return None + return self._value != other.value + + def __int__(self): + return self._value + + @property + def value(self): + """The value of this dimension, or None if it is unknown.""" + return self._value + + def is_compatible_with(self, other): + """Returns true if `other` is compatible with this Dimension. + + Two known Dimensions are compatible if they have the same value. + An unknown Dimension is compatible with all other Dimensions. + + Args: + other: Another Dimension. + + Returns: + True if this Dimension and `other` are compatible. + """ + other = as_dimension(other) + return (self._value is None + or other.value is None + or self._value == other.value) + + def assert_is_compatible_with(self, other): + """Raises an exception if `other` is not compatible with this Dimension. + + Args: + other: Another Dimension. + + Raises: + ValueError: If `self` and `other` are not compatible (see + is_compatible_with). 
+ """ + if not self.is_compatible_with(other): + raise ValueError("Dimensions %s and %s are not compatible" + % (self, other)) + + def merge_with(self, other): + """Returns a Dimension that combines the information in `self` and `other`. + + Dimensions are combined as follows: + + Dimension(n) .merge_with(Dimension(n)) == Dimension(n) + Dimension(n) .merge_with(Dimension(None)) == Dimension(n) + Dimension(None).merge_with(Dimension(n)) == Dimension(n) + Dimension(None).merge_with(Dimension(None)) == Dimension(None) + Dimension(n) .merge_with(Dimension(m)) raises ValueError for n != m + + Args: + other: Another Dimension. + + Returns: + A Dimension containing the combined information of `self` and + `other`. + + Raises: + ValueError: If `self` and `other` are not compatible (see + is_compatible_with). + """ + other = as_dimension(other) + self.assert_is_compatible_with(other) + if self._value is None: + return Dimension(other.value) + else: + return Dimension(self._value) + + def __add__(self, other): + """Returns the sum of `self` and `other`. + + Dimensions are summed as follows: + + Dimension(m) + Dimension(n) == Dimension(m + n) + Dimension(m) + Dimension(None) == Dimension(None) + Dimension(None) + Dimension(n) == Dimension(None) + Dimension(None) + Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. + + Returns: + A Dimension whose value is the sum of `self` and `other`. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value + other.value) + + def __sub__(self, other): + """Returns the subtraction of `other` from `self`. + + Dimensions are subtracted as follows: + + Dimension(m) - Dimension(n) == Dimension(m - n) + Dimension(m) - Dimension(None) == Dimension(None) + Dimension(None) - Dimension(n) == Dimension(None) + Dimension(None) - Dimension(None) == Dimension(None) + + Args: + other: Another Dimension. 
+
+    Returns:
+      A Dimension whose value is the subtraction of `other` from `self`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value - other.value)
+
+  def __mul__(self, other):
+    """Returns the product of `self` and `other`.
+
+    Dimensions are multiplied as follows:
+
+      Dimension(m) * Dimension(n) == Dimension(m * n)
+      Dimension(m) * Dimension(None) == Dimension(None)
+      Dimension(None) * Dimension(n) == Dimension(None)
+      Dimension(None) * Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the product of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value * other.value)
+
+  def __div__(self, other):
+    """Returns the quotient of `self` and `other`.
+
+    Dimensions are divided as follows:
+
+      Dimension(m) / Dimension(n) == Dimension(m / n)
+      Dimension(m) / Dimension(None) == Dimension(None)
+      Dimension(None) / Dimension(n) == Dimension(None)
+      Dimension(None) / Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is the quotient of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(self._value / other.value)
+
+  def __mod__(self, other):
+    """Returns `self` modulo `other`.
+
+    Dimension moduli are computed as follows:
+
+      Dimension(m) % Dimension(n) == Dimension(m % n)
+      Dimension(m) % Dimension(None) == Dimension(None)
+      Dimension(None) % Dimension(n) == Dimension(None)
+      Dimension(None) % Dimension(None) == Dimension(None)
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      A Dimension whose value is `self` modulo `other`.
+ """ + other = as_dimension(other) + if self._value is None or other.value is None: + return Dimension(None) + else: + return Dimension(self._value % other.value) + + def __lt__(self, other): + """Returns True if `self` is known to be less than `other`. + + Dimensions are compared as follows: + + Dimension(m) < Dimension(n) == m < n + Dimension(m) < Dimension(None) == None + Dimension(None) < Dimension(n) == None + Dimension(None) < Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value < other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value < other.value + + def __le__(self, other): + """Returns True if `self` is known to be less than or equal to `other`. + + Dimensions are compared as follows: + + Dimension(m) <= Dimension(n) == m <= n + Dimension(m) <= Dimension(None) == None + Dimension(None) <= Dimension(n) == None + Dimension(None) <= Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value <= other.value` if both are known, otherwise + None. + """ + other = as_dimension(other) + if self._value is None or other.value is None: + return None + else: + return self._value <= other.value + + def __gt__(self, other): + """Returns True if `self` is known to be greater than `other`. + + Dimensions are compared as follows: + + Dimension(m) > Dimension(n) == m > n + Dimension(m) > Dimension(None) == None + Dimension(None) > Dimension(n) == None + Dimension(None) > Dimension(None) == None + + Args: + other: Another Dimension. + + Returns: + The value of `self.value > other.value` if both are known, otherwise + None. 
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value > other.value
+
+  def __ge__(self, other):
+    """Returns True if `self` is known to be greater than or equal to `other`.
+
+    Dimensions are compared as follows:
+
+      Dimension(m) >= Dimension(n) == m >= n
+      Dimension(m) >= Dimension(None) == None
+      Dimension(None) >= Dimension(n) == None
+      Dimension(None) >= Dimension(None) == None
+
+    Args:
+      other: Another Dimension.
+
+    Returns:
+      The value of `self.value >= other.value` if both are known, otherwise
+      None.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return None
+    else:
+      return self._value >= other.value
+
+
+def as_dimension(value):
+  """Converts the given value to a Dimension.
+
+  A Dimension input will be returned unmodified.
+  An input of `None` will be converted to an unknown Dimension.
+  An integer input will be converted to a Dimension with that value.
+
+  Args:
+    value: The value to be converted.
+
+  Returns:
+    A Dimension corresponding to the given value.
+  """
+  if isinstance(value, Dimension):
+    return value
+  else:
+    return Dimension(value)
+
+
+class TensorShape(object):
+  """Represents the shape of a `Tensor`.
+
+  A `TensorShape` represents a possibly-partial shape specification for a
+  `Tensor`. It may be one of the following:
+
+  * *Fully-known shape:* has a known number of dimensions and a known size
+    for each dimension.
+  * *Partially-known shape:* has a known number of dimensions, and an unknown
+    size for one or more dimension.
+  * *Unknown shape:* has an unknown number of dimensions, and an unknown
+    size in all dimensions.
+
+  If a tensor is produced by an operation of type `"Foo"`, its shape
+  may be inferred if there is a registered shape function for
+  `"Foo"`. See [`tf.RegisterShape()`](framework.md#RegisterShape)
+  for details of shape
+  functions and how to register them.
Alternatively, the shape may be set + explicitly using [`Tensor.set_shape()`](framework.md#Tensor.set_shape). + + @@merge_with + @@concatenate + + @@ndims + @@dims + @@as_list + @@is_compatible_with + @@is_fully_defined + + @@with_rank + @@with_rank_at_least + @@with_rank_at_most + + @@assert_has_rank + @@assert_same_rank + @@assert_is_compatible_with + @@assert_is_fully_defined + """ + + def __init__(self, dims): + """Creates a new TensorShape with the given dimensions. + + Args: + dims: A list of Dimensions, or None if the shape is unspecified. + DEPRECATED: A single integer is treated as a singleton list. + """ + # TODO(irving): Eliminate the single integer special case. + if dims is None: + self._dims = None + else: + try: + dims_iter = iter(dims) + except TypeError: + # Treat as a singleton dimension + self._dims = [as_dimension(dims)] + else: + # Got a list of dimensions + self._dims = map(as_dimension, dims_iter) + + def __repr__(self): + return "TensorShape(%s)" % str(self._dims) + + @property + def dims(self): + """Returns a list of Dimensions, or None if the shape is unspecified.""" + return self._dims + + @property + def ndims(self): + """Returns the rank of this shape, or None if it is unspecified.""" + if self._dims is None: + return None + else: + return len(self._dims) + + def __len__(self): + """Returns the rank of this shape, or raises ValueError if unspecified.""" + if self._dims is None: + raise ValueError("Cannot take the length of Shape with unknown rank.") + return len(self._dims) + + def __nonzero__(self): + """Returns True if this shape contains non-zero information.""" + return self._dims is not None + + def __getitem__(self, key): + """Returns the value of a dimension or a shape, depending on the key. + + Args: + key: If `key` is an integer, returns the dimension at that index; + otherwise if `key` is a slice, returns a TensorShape whose + dimensions are those selected by the slice from `self`. 
+ + Returns: + A dimension if `key` is an integer, or a `TensorShape` if `key` is a + slice. + + Raises: + ValueError: If `key` is a slice, and any of its elements are negative, or + if `self` is completely unknown and the step is set. + """ + if self._dims is not None: + if isinstance(key, slice): + return TensorShape(self._dims[key]) + else: + return self._dims[key] + else: + if isinstance(key, slice): + start = key.start if key.start is not None else 0 + stop = key.stop + + if key.step is not None: + # TODO(mrry): Handle these maybe. + raise ValueError("Steps are not yet handled") + if stop is None: + # NOTE(mrry): This implies that TensorShape(None) is compatible with + # TensorShape(None)[1:], which is obviously not true. It would be + # possible to track the number of dimensions symbolically, + # and perhaps we should do that. + return unknown_shape() + elif start < 0 or stop < 0: + # TODO(mrry): Handle this better, as it will be useful for handling + # suffixes of otherwise unknown shapes. + return unknown_shape() + else: + return unknown_shape(ndims=stop-start) + else: + return Dimension(None) + + def num_elements(self): + """Returns the total number of elements, or none for incomplete shapes.""" + if self.is_fully_defined(): + size = 1 + for dim in self._dims: + size *= dim.value + return size + else: + return None + + def merge_with(self, other): + """Returns a `TensorShape` combining the information in `self` and `other`. + + The dimensions in `self` and `other` are merged elementwise, + according to the rules defined for `Dimension.merge_with()`. + + Args: + other: Another `TensorShape`. + + Returns: + A `TensorShape` containing the combined information of `self` and + `other`. + + Raises: + ValueError: If `self` and `other` are not compatible. 
+ """ + other = as_shape(other) + if self._dims is None: + return other + else: + self.assert_same_rank(other) + new_dims = [] + for i, dim in enumerate(self._dims): + new_dims.append(dim.merge_with(other[i])) + return TensorShape(new_dims) + + def concatenate(self, other): + """Returns the concatenation of the dimension in `self` and `other`. + + *N.B.* If either `self` or `other` is completely unknown, + concatenation will discard information about the other shape. In + future, we might support concatenation that preserves this + information for use with slicing. + + Args: + other: Another `TensorShape`. + + Returns: + A `TensorShape` whose dimensions are the concatenation of the + dimensions in `self` and `other`. + """ + # TODO(mrry): Handle the case where we concatenate a known shape with a + # completely unknown shape, so that we can use the partial information. + other = as_shape(other) + if self._dims is None or other.dims is None: + return unknown_shape() + else: + return TensorShape(self._dims + other.dims) + + def assert_same_rank(self, other): + """Raises an exception if `self` and `other` do not have compatible ranks. + + Args: + other: Another `TensorShape`. + + Raises: + ValueError: If `self` and `other` do not represent shapes with the + same rank. + """ + other = as_shape(other) + if self.ndims is not None and other.ndims is not None: + if self.ndims != other.ndims: + raise ValueError( + "Shapes %s and %s must have the same rank" % (self, other)) + + def assert_has_rank(self, rank): + """Raises an exception if `self` is not compatible with the given `rank`. + + Args: + rank: An integer. + + Raises: + ValueError: If `self` does not represent a shape with the given `rank`. + """ + if self.ndims not in (None, rank): + raise ValueError("Shape %s must have rank %d" % (self, rank)) + + def with_rank(self, rank): + """Returns a shape based on `self` with the given rank. + + This method promotes a completely unknown shape to one with a + known rank. 
+ + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with the given rank. + + Raises: + ValueError: If `self` does not represent a shape with the given `rank`. + """ + return self.merge_with(unknown_shape(ndims=rank)) + + def with_rank_at_least(self, rank): + """Returns a shape based on `self` with at least the given rank. + + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with at least the given + rank. + + Raises: + ValueError: If `self` does not represent a shape with at least the given + `rank`. + """ + if self.ndims is not None and self.ndims < rank: + raise ValueError("Shape %s must have rank at least %d" % (self, rank)) + else: + return self + + def with_rank_at_most(self, rank): + """Returns a shape based on `self` with at most the given rank. + + Args: + rank: An integer. + + Returns: + A shape that is at least as specific as `self` with at most the given + rank. + + Raises: + ValueError: If `self` does not represent a shape with at most the given + `rank`. + """ + if self.ndims is not None and self.ndims > rank: + raise ValueError("Shape %s must have rank at most %d" % (self, rank)) + else: + return self + + def is_compatible_with(self, other): + """Returns True iff `self` is compatible with `other`. + + Two possibly-partially-defined shapes are compatible if there + exists a fully-defined shape that both shapes can represent. Thus, + compatibility allows the shape inference code to reason about + partially-defined shapes. For example: + + * TensorShape(None) is compatible with all shapes. + + * TensorShape([None, None]) is compatible with all two-dimensional + shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is + not compatible with, for example, TensorShape([None]) or + TensorShape([None, None, None]). 
+ + * TensorShape([32, None]) is compatible with all two-dimensional shapes + with size 32 in the 0th dimension, and also TensorShape([None, None]) + and TensorShape(None). It is not compatible with, for example, + TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). + + * TensorShape([32, 784]) is compatible with itself, and also + TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, + None]) and TensorShape(None). It is not compatible with, for example, + TensorShape([32, 1, 784]) or TensorShape([None]). + + The compatibility relation is reflexive and symmetric, but not + transitive. For example, TensorShape([32, 784]) is compatible with + TensorShape(None), and TensorShape(None) is compatible with + TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with + TensorShape([4, 4]). + + Args: + other: Another TensorShape. + + Returns: + True iff `self` is compatible with `other`. + + """ + other = as_shape(other) + if self._dims is not None and other.dims is not None: + if self.ndims != other.ndims: + return False + for x_dim, y_dim in zip(self._dims, other.dims): + if not x_dim.is_compatible_with(y_dim): + return False + return True + + def assert_is_compatible_with(self, other): + """Raises exception if `self` and `other` do not represent the same shape. + + This method can be used to assert that there exists a shape that both + `self` and `other` represent. + + Args: + other: Another TensorShape. + + Raises: + ValueError: If `self` and `other` do not represent the same shape. + """ + if not self.is_compatible_with(other): + raise ValueError("Shapes %s and %s are incompatible" % (self, other)) + + def is_fully_defined(self): + """Returns True iff `self` is fully defined in every dimension.""" + return (self._dims is not None + and all(dim.value is not None for dim in self._dims)) + + def assert_is_fully_defined(self): + """Raises an exception if `self` is not fully defined in every dimension. 
+ + Raises: + ValueError: If `self` does not have a known value for every dimension. + """ + if not self.is_fully_defined(): + raise ValueError("Shape %s is not fully defined" % self) + + def as_dimension_list(self): + """DEPRECATED: use as_list().""" + self.assert_is_fully_defined() + return self.as_list() + + def as_list(self): + """Returns a list of integers or None for each dimension.""" + return [dim.value for dim in self._dims] + + def __eq__(self, other): + """Returns True if `self` is equivalent to `other`.""" + other = as_shape(other) + return self._dims == other.dims + + def __ne__(self, other): + """Returns True if `self` is known to be different from `other`.""" + other = as_shape(other) + if self.ndims is None or other.ndims is None: + raise ValueError("The inequality of unknown TensorShapes is undefined.") + if self.ndims != other.ndims: + return True + return self._dims != other.dims + + +def as_shape(shape): + """Converts the given object to a TensorShape.""" + if isinstance(shape, TensorShape): + return shape + else: + return TensorShape(shape) + + +def unknown_shape(ndims=None): + """Returns an unknown TensorShape, optionally with a known rank. + + Args: + ndims: (Optional) If specified, the number of dimensions in the shape. + + Returns: + An unknown TensorShape. + """ + if ndims is None: + return TensorShape(None) + else: + return TensorShape([Dimension(None) for _ in range(ndims)]) + + +def scalar(): + """Returns a shape representing a scalar.""" + return TensorShape([]) + + +def vector(length): + """Returns a shape representing a vector. + + Args: + length: The length of the vector, which may be None if unknown. + + Returns: + A TensorShape representing a vector of the given length. + """ + return TensorShape([length]) + + +def matrix(rows, cols): + """Returns a shape representing a matrix. + + Args: + rows: The number of rows in the matrix, which may be None if unknown. 
+ cols: The number of columns in the matrix, which may be None if unknown. + + Returns: + A TensorShape representing a matrix of the given size. + """ + return TensorShape([rows, cols]) diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py new file mode 100644 index 0000000000..9743a8d199 --- /dev/null +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -0,0 +1,232 @@ +"""Functional tests for shape inference helper classes.""" +import tensorflow.python.platform + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class DimensionTest(test_util.TensorFlowTestCase): + + def testDimension(self): + dim = tensor_shape.Dimension(12) + self.assertEqual(12, dim.value) + self.assertEqual(12, int(dim)) + self.assertEqual(dim, tensor_shape.Dimension(12)) + self.assertEqual(tensor_shape.Dimension(15), + dim + tensor_shape.Dimension(3)) + self.assertEqual(tensor_shape.Dimension(15), dim + 3) + self.assertEqual(tensor_shape.Dimension(24), + dim * tensor_shape.Dimension(2)) + self.assertEqual(tensor_shape.Dimension(24), dim * 2) + self.assertEqual(tensor_shape.Dimension(6), dim / tensor_shape.Dimension(2)) + self.assertEqual(tensor_shape.Dimension(6), dim / 2) + self.assertEqual(tensor_shape.Dimension(12), + dim.merge_with(tensor_shape.Dimension(12))) + self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12)) + self.assertLess(tensor_shape.Dimension(12), tensor_shape.Dimension(13)) + self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12)) + self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(12)) + self.assertLessEqual(tensor_shape.Dimension(12), tensor_shape.Dimension(13)) + self.assertGreater(tensor_shape.Dimension(13), tensor_shape.Dimension(12)) + self.assertGreaterEqual(tensor_shape.Dimension(12), + tensor_shape.Dimension(12)) + 
self.assertGreaterEqual(tensor_shape.Dimension(13), + tensor_shape.Dimension(12)) + with self.assertRaises(ValueError): + dim.merge_with(tensor_shape.Dimension(13)) + + def testUnknownDimension(self): + dim = tensor_shape.Dimension(None) + self.assertIs(None, dim.value) + self.assertEqual(dim.value, tensor_shape.Dimension(None).value) + self.assertEqual(tensor_shape.Dimension(None).value, + (dim + tensor_shape.Dimension(None)).value) + self.assertEqual(tensor_shape.Dimension(None).value, + (dim * tensor_shape.Dimension(None)).value) + self.assertEqual(tensor_shape.Dimension(None).value, + (dim / tensor_shape.Dimension(None)).value) + self.assertEqual(tensor_shape.Dimension(None).value, + dim.merge_with(tensor_shape.Dimension(None)).value) + self.assertIs(None, + tensor_shape.Dimension(None) < tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) <= tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) > tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) >= tensor_shape.Dimension(None)) + + def testKnownAndUnknownDimensions(self): + known = tensor_shape.Dimension(12) + unknown = tensor_shape.Dimension(None) + self.assertEqual( + tensor_shape.Dimension(None).value, (known + unknown).value) + self.assertEqual( + tensor_shape.Dimension(None).value, (unknown + known).value) + self.assertEqual( + tensor_shape.Dimension(None).value, (known * unknown).value) + self.assertEqual( + tensor_shape.Dimension(None).value, (unknown * known).value) + self.assertEqual( + tensor_shape.Dimension(None).value, (known / unknown).value) + self.assertEqual( + tensor_shape.Dimension(None).value, (unknown / known).value) + self.assertEqual( + tensor_shape.Dimension(12), known.merge_with(unknown)) + self.assertEqual( + tensor_shape.Dimension(12), unknown.merge_with(known)) + self.assertIs(None, + tensor_shape.Dimension(12) < tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(12) 
<= tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(12) > tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(12) >= tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) < tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(None) <= tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(None) > tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(None) >= tensor_shape.Dimension(12)) + + def testAsDimension(self): + self.assertEqual(tensor_shape.Dimension(12), + tensor_shape.as_dimension(tensor_shape.Dimension(12))) + self.assertEqual(tensor_shape.Dimension(12), tensor_shape.as_dimension(12)) + self.assertEqual( + tensor_shape.Dimension(None).value, + tensor_shape.as_dimension(tensor_shape.Dimension(None)).value) + self.assertEqual(tensor_shape.Dimension(None).value, + tensor_shape.as_dimension(None).value) + + def testEquality(self): + self.assertTrue(tensor_shape.Dimension(12) == tensor_shape.Dimension(12)) + self.assertFalse(tensor_shape.Dimension(12) == tensor_shape.Dimension(13)) + self.assertIs(None, + tensor_shape.Dimension(12) == tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) == tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(None) == tensor_shape.Dimension(None)) + + def testInequality(self): + self.assertTrue(tensor_shape.Dimension(12) != tensor_shape.Dimension(13)) + self.assertFalse(tensor_shape.Dimension(12) != tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(12) != tensor_shape.Dimension(None)) + self.assertIs(None, + tensor_shape.Dimension(None) != tensor_shape.Dimension(12)) + self.assertIs(None, + tensor_shape.Dimension(None) != tensor_shape.Dimension(None)) + + +class ShapeTest(test_util.TensorFlowTestCase): + + def testUnknownShape(self): + s = tensor_shape.TensorShape(None) + with self.assertRaises(ValueError): + 
s.assert_is_fully_defined() + self.assertIs(None, s.ndims) + with self.assertRaises(ValueError): + len(s) + self.assertFalse(s) + self.assertIs(None, s.dims) + + def testFullyDefinedShape(self): + s = tensor_shape.TensorShape([tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7)]) + s.assert_is_fully_defined() + self.assertEqual(3, s.ndims) + self.assertEqual(3, len(s)) + self.assertTrue(s) + s.assert_has_rank(3) + self.assertEqual([tensor_shape.Dimension(3), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7)], s.dims) + self.assertEqual(tensor_shape.Dimension(3), s[0]) + self.assertEqual(tensor_shape.Dimension(4), s[1]) + self.assertEqual(tensor_shape.Dimension(7), s[2]) + self.assertEqual([3, 4, 7], s.as_list()) + s.assert_is_compatible_with([3, 4, 7]) + s.assert_same_rank([6, 3, 7]) + + def testPartiallyDefinedShape(self): + s = tensor_shape.TensorShape([tensor_shape.Dimension(3), + tensor_shape.Dimension(None), + tensor_shape.Dimension(7)]) + with self.assertRaises(ValueError): + s.assert_is_fully_defined() + self.assertEqual(3, s.ndims) + self.assertEqual(3, len(s)) + self.assertTrue(s) + s.assert_has_rank(3) + self.assertEqual(tensor_shape.Dimension(3), s[0]) + self.assertEqual(tensor_shape.Dimension(None).value, s[1].value) + self.assertEqual(tensor_shape.Dimension(7), s[2]) + s.assert_same_rank([6, 3, 7]) + + def testMergeFullShapes(self): + self.assertEqual([3, 4, 7], + tensor_shape.TensorShape([3, 4, 7]).merge_with( + tensor_shape.TensorShape([3, 4, 7])).as_list()) + with self.assertRaises(ValueError): + tensor_shape.TensorShape([3, 4, 7]).merge_with( + tensor_shape.TensorShape([6, 3, 7])) + + def testMergePartialShapes(self): + s1 = tensor_shape.TensorShape([tensor_shape.Dimension(3), + tensor_shape.Dimension(None), + tensor_shape.Dimension(7)]) + s2 = tensor_shape.TensorShape([tensor_shape.Dimension(None), + tensor_shape.Dimension(4), + tensor_shape.Dimension(7)]) + self.assertEqual([3, 4, 7], 
s1.merge_with(s2).as_list()) + + def testMergeFullAndUnknownShape(self): + self.assertEqual([3, 4, 7], + tensor_shape.TensorShape([3, 4, 7]).merge_with( + tensor_shape.TensorShape(None)).as_list()) + + def testSlice(self): + known = tensor_shape.TensorShape([0, 1, 2, 3, 4]) + self.assertEqual(tensor_shape.Dimension(2), known[2]) + tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with(known[1:4]) + + unknown = tensor_shape.TensorShape(None) + self.assertEqual(tensor_shape.Dimension(None).value, unknown[2].value) + tensor_shape.TensorShape( + [None, None, None]).assert_is_compatible_with(unknown[1:4]) + + def testConcatenate(self): + tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( + tensor_shape.TensorShape([1, 2]).concatenate( + tensor_shape.TensorShape([3, 4]))) + tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( + tensor_shape.TensorShape([1, 2]).concatenate( + tensor_shape.TensorShape(None))) + tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( + tensor_shape.TensorShape(None).concatenate( + tensor_shape.TensorShape([3, 4]))) + tensor_shape.TensorShape([1, 2, 3, 4]).assert_is_compatible_with( + tensor_shape.TensorShape(None).concatenate( + tensor_shape.TensorShape(None))) + tensor_shape.TensorShape([1, 2, 3]).assert_is_compatible_with( + tensor_shape.TensorShape([1, 2]).concatenate( + tensor_shape.Dimension(3))) + + def testHelpers(self): + tensor_shape.TensorShape([]).assert_is_compatible_with( + tensor_shape.scalar()) + tensor_shape.TensorShape([37]).assert_is_compatible_with( + tensor_shape.vector(37)) + tensor_shape.TensorShape( + [94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py new file mode 100644 index 0000000000..81ed54c473 --- /dev/null +++ b/tensorflow/python/framework/tensor_util.py @@ -0,0 +1,511 @@ +"""Utilities to create 
TensorProtos.""" +import numbers +import tensorflow.python.platform +import numpy as np + +from tensorflow.core.framework import tensor_pb2 +from tensorflow.core.framework import tensor_shape_pb2 + +# TODO(opensource): Add support for pyx_library in the open-source build. +# For now, we use the slow versions that fast_tensor_util replaces. +# pylint: disable=g-import-not-at-top +try: + from tensorflow.python.framework import fast_tensor_util + _FAST_TENSOR_UTIL_AVAILABLE = True +except ImportError: + _FAST_TENSOR_UTIL_AVAILABLE = False + +from tensorflow.python.framework import ops +from tensorflow.python.framework import types +# pylint: enable=g-import-not-at-top + + +if _FAST_TENSOR_UTIL_AVAILABLE: + _NP_TO_APPEND_FN = { + np.float32: fast_tensor_util.AppendFloat32ArrayToTensorProto, + np.float64: fast_tensor_util.AppendFloat64ArrayToTensorProto, + np.int32: fast_tensor_util.AppendInt32ArrayToTensorProto, + np.int64: fast_tensor_util.AppendInt64ArrayToTensorProto, + np.uint8: fast_tensor_util.AppendUInt8ArrayToTensorProto, + np.int16: fast_tensor_util.AppendInt16ArrayToTensorProto, + np.int8: fast_tensor_util.AppendInt8ArrayToTensorProto, + np.complex64: fast_tensor_util.AppendComplex64ArrayToTensorProto, + np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto, + np.object: fast_tensor_util.AppendObjectArrayToTensorProto, + np.bool: fast_tensor_util.AppendBoolArrayToTensorProto, + types.qint8.as_numpy_dtype: + fast_tensor_util.AppendInt8ArrayToTensorProto, + types.quint8.as_numpy_dtype: + fast_tensor_util.AppendUInt8ArrayToTensorProto, + types.qint32.as_numpy_dtype: + fast_tensor_util.AppendInt32ArrayToTensorProto, + # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16. 
+ } +else: + + def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.float_val.extend([np.asscalar(x) for x in proto_values]) + + def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.double_val.extend([np.asscalar(x) for x in proto_values]) + + def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.int_val.extend([np.asscalar(x) for x in proto_values]) + + def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.int64_val.extend([np.asscalar(x) for x in proto_values]) + + def SlowAppendComplexArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.scomplex_val.extend([np.asscalar(v) + for x in proto_values + for v in [x.real, x.imag]]) + + def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.string_val.extend([str(x) for x in proto_values]) + + def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.bool_val.extend([np.asscalar(x) for x in proto_values]) + + _NP_TO_APPEND_FN = { + np.float32: SlowAppendFloat32ArrayToTensorProto, + np.float64: SlowAppendFloat64ArrayToTensorProto, + np.int32: SlowAppendIntArrayToTensorProto, + np.int64: SlowAppendInt64ArrayToTensorProto, + np.uint8: SlowAppendIntArrayToTensorProto, + np.int16: SlowAppendIntArrayToTensorProto, + np.int8: SlowAppendIntArrayToTensorProto, + np.complex64: SlowAppendComplexArrayToTensorProto, + np.complex128: SlowAppendComplexArrayToTensorProto, + np.object: SlowAppendObjectArrayToTensorProto, + np.bool: SlowAppendBoolArrayToTensorProto, + types.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + types.quint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + types.qint32.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16. + } + + +def GetFromNumpyDTypeDict(dtype_dict, dtype): + # NOTE: dtype_dict.get(dtype) always returns None. 
+ for key, val in dtype_dict.iteritems(): + if key == dtype: + return val + return None + + +def GetNumpyAppendFn(dtype): + # numpy dtype for strings are variable length. We can not compare + # dtype with a single constant (np.string does not exist) to decide + # dtype is a "string" type. We need to compare the dtype.type to be + # sure it's a string type. + if dtype.type == np.string_ or dtype.type == np.unicode_: + if _FAST_TENSOR_UTIL_AVAILABLE: + return fast_tensor_util.AppendObjectArrayToTensorProto + else: + return SlowAppendObjectArrayToTensorProto + return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype) + + +def MakeTensorShapeProto(shape): + """Create a TensorShapeProto. + + Args: + shape: List of integers representing the dimensions of the tensor. + + Returns: + A TensorShapeProto. + """ + return tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape]) + + +def TensorShapeProtoToList(shape): + """Convert a TensorShape to a list. + + Args: + shape: A TensorShapeProto. + + Returns: + List of integers representing the dimensions of the tensor. 
+ """ + return [dim.size for dim in shape.dim] + + +def _GetDenseDimensions(list_of_lists): + """Returns the inferred dense dimensions of a list of lists.""" + if not isinstance(list_of_lists, (list, tuple)): + return [] + elif not list_of_lists: + return [0] + else: + return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0]) + + +def _FlattenToStrings(nested_strings): + if isinstance(nested_strings, list): + for inner in nested_strings: + for flattened_string in _FlattenToStrings(inner): + yield flattened_string + else: + yield nested_strings + + +_TENSOR_CONTENT_TYPES = frozenset([ + types.float32, types.float64, types.int32, types.uint8, types.int16, + types.int8, types.int64 +]) + + +def _FirstNotNone(l): + for x in l: + if x is not None: + return x + return None + + +def _FilterInt(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterInt(x) for x in v]) + return None if isinstance(v, numbers.Integral) else repr(v) + + +def _FilterFloat(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterFloat(x) for x in v]) + return None if isinstance(v, numbers.Real) else repr(v) + + +def _FilterComplex(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterComplex(x) for x in v]) + return None if isinstance(v, numbers.Complex) else repr(v) + + +def _FilterStr(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterStr(x) for x in v]) + return None if isinstance(v, basestring) else repr(v) + + +def _FilterBool(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterBool(x) for x in v]) + return None if isinstance(v, bool) else repr(v) + + +def _FilterNotTensor(v): + if isinstance(v, (list, tuple)): + return _FirstNotNone([_FilterNotTensor(x) for x in v]) + return repr(v) if isinstance(v, ops.Tensor) else None + + +_TF_TO_IS_OK = { + types.float32: _FilterFloat, + types.float64: _FilterFloat, + types.int32: _FilterInt, + types.uint8: _FilterInt, + types.int16: _FilterInt, + types.int8: 
_FilterInt, + types.string: _FilterStr, + types.complex64: _FilterComplex, + types.int64: _FilterInt, + types.bool: _FilterBool, + types.qint32: _FilterInt, + types.quint8: _FilterInt, + types.qint8: _FilterInt, +} + + +def _AssertCompatible(values, dtype): + fn = _TF_TO_IS_OK.get(dtype, _FilterNotTensor) + mismatch = fn(values) + if mismatch is not None: + if dtype is None: + raise TypeError("List of Tensors when single Tensor expected") + else: + raise TypeError("Expected %s, got %s instead." % + (dtype.name, mismatch)) + + +def make_tensor_proto(values, dtype=None, shape=None): + """Create a TensorProto. + + Args: + values: Values to put in the TensorProto. + dtype: Optional tensor_pb2 DataType value. + shape: List of integers representing the dimensions of tensor. + + Returns: + A TensorProto. Depending on the type, it may contain data in the + "tensor_content" attribute, which is not directly useful to Python programs. + To access the values you should convert the proto back to a numpy ndarray + with tensor_util.MakeNdarray(proto). + + Raises: + TypeError: if unsupported types are provided. + ValueError: if arguments have inappropriate values. + + make_tensor_proto accepts "values" of a python scalar, a python list, a + numpy ndarray, or a numpy scalar. + + If "values" is a python scalar or a python list, make_tensor_proto + first convert it to numpy ndarray. If dtype is None, the + conversion tries its best to infer the right numpy data + type. Otherwise, the resulting numpy array has a compatible data + type with the given dtype. + + In either case above, the numpy ndarray (either the caller provided + or the auto converted) must have the compatible type with dtype. + + make_tensor_proto then converts the numpy array to a tensor proto. + + If "shape" is None, the resulting tensor proto represents the numpy + array precisely. + + Otherwise, "shape" specifies the tensor's shape and the numpy array + can not have more elements than what "shape" specifies. 
+ + """ + if dtype: + dtype = types.as_dtype(dtype) + + # We first convert value to a numpy array or scalar. + if isinstance(values, (np.ndarray, np.generic)): + if dtype: + nparray = values.astype(dtype.as_numpy_dtype) + else: + nparray = values + else: + if values is None: + raise ValueError("None values not supported.") + # if dtype is provided, forces numpy array to be the type + # provided if possible. + np_dt = dtype.as_numpy_dtype if dtype else None + if np.prod(shape) == 0: + nparray = np.empty(shape, dtype=np_dt) + else: + _AssertCompatible(values, dtype) + nparray = np.array(values, dtype=np_dt) + if list(nparray.shape) != _GetDenseDimensions(values): + raise ValueError("Argument must be a dense tensor: %s" % values) + # python/numpy default float type is float64. We prefer float32 instead. + if (nparray.dtype == np.float64) and dtype is None: + nparray = nparray.astype(np.float32) + # python/numpy default int type is int64. We prefer int32 instead. + elif (nparray.dtype == np.int64) and dtype is None: + nparray = nparray.astype(np.int32) + + # if dtype is provided, it must be compatible with what numpy + # conversion says. + numpy_dtype = types.as_dtype(nparray.dtype) + if numpy_dtype is None: + raise TypeError("Unrecognized data type: %s" % nparray.dtype) + + # If dtype was specified and is a quantized type, we convert + # numpy_dtype back into the quantized version. + if dtype in [types.qint8, types.quint8, types.qint32]: + numpy_dtype = dtype + + if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype: + raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype)) + + # If shape is not given, get the shape from the numpy array. + if shape is None: + shape = nparray.shape + is_same_size = True + shape_size = nparray.size + else: + shape = [int(dim) for dim in shape] + shape_size = np.prod(shape) + is_same_size = shape_size == nparray.size + + if nparray.size > shape_size: + raise ValueError( + "Too many elements provided. 
Needed at most %d, but received %d" % + (shape_size, nparray.size)) + + tensor_proto = tensor_pb2.TensorProto( + dtype=numpy_dtype.as_datatype_enum, + tensor_shape=MakeTensorShapeProto(shape)) + + if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1: + tensor_proto.tensor_content = nparray.tostring() + return tensor_proto + + # If we were not given values as a numpy array, compute the proto_values + # from the given values directly, to avoid numpy trimming nulls from the + # strings. Since values could be a list of strings, or a multi-dimensional + # list of lists that might or might not correspond to the given shape, + # we flatten it conservatively. + if numpy_dtype == types.string and not isinstance(values, np.ndarray): + proto_values = _FlattenToStrings(values) + tensor_proto.string_val.extend([str(x) for x in proto_values]) + return tensor_proto + + # TensorFlow expects C order (a.k.a., eigen row major). + proto_values = nparray.ravel() + + append_fn = GetNumpyAppendFn(proto_values.dtype) + if append_fn is None: + raise TypeError("Element type not supported in TensorProto: %s" % + numpy_dtype.name) + append_fn(tensor_proto, proto_values) + + return tensor_proto + + +def MakeNdarray(tensor): + """Create a numpy ndarray from a tensor. + + Create a numpy ndarray with the same shape and data as the tensor. + + Args: + tensor: A TensorProto. + + Returns: + A numpy array with the tensor contents. + + Raises: + TypeError: if tensor has unsupported type. 
+ + """ + shape = [d.size for d in tensor.tensor_shape.dim] + num_elements = np.prod(shape) + tensor_dtype = types.as_dtype(tensor.dtype) + dtype = tensor_dtype.as_numpy_dtype + + if tensor.tensor_content: + return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape) + elif tensor_dtype == types.float32: + if len(tensor.float_val) == 1: + return np.repeat(np.array(tensor.float_val[0], dtype=dtype), + num_elements).reshape(shape) + else: + return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape) + elif tensor_dtype == types.float64: + if len(tensor.double_val) == 1: + return np.repeat(np.array(tensor.double_val[0], dtype=dtype), + num_elements).reshape(shape) + else: + return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape) + elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8, + types.qint32, types.quint8, types.qint8, + types.bfloat16]: + if len(tensor.int_val) == 1: + return np.repeat(np.array(tensor.int_val[0], dtype=dtype), + num_elements).reshape(shape) + else: + return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape) + elif tensor_dtype == types.int64: + if len(tensor.int64_val) == 1: + return np.repeat(np.array(tensor.int64_val[0], dtype=dtype), + num_elements).reshape(shape) + else: + return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape) + elif tensor_dtype == types.string: + if len(tensor.string_val) == 1: + return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype), + num_elements).reshape(shape) + else: + return np.array([str(x) for x in tensor.string_val], + dtype=dtype).reshape(shape) + elif tensor_dtype == types.complex64: + it = iter(tensor.scomplex_val) + if len(tensor.scomplex_val) == 2: + return np.repeat(np.array(complex(tensor.scomplex_val[0], + tensor.scomplex_val[1]), dtype=dtype), + num_elements).reshape(shape) + else: + return np.array([complex(x[0], x[1]) for x in zip(it, it)], + dtype=dtype).reshape(shape) + elif tensor_dtype == types.bool: + if len(tensor.bool_val) 
== 1: + return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), + num_elements).reshape(shape) + else: + return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape) + else: + raise TypeError("Unsupported tensor type: %s" % tensor.dtype) + + +def ShapeEquals(tensor_proto, shape): + """Returns True if "tensor_proto" has the given "shape". + + Args: + tensor_proto: A TensorProto. + shape: A tensor shape, expressed as a TensorShape, list, or tuple. + + Returns: + True if "tensor_proto" has the given "shape", otherwise False. + + Raises: + TypeError: If "tensor_proto" is not a TensorProto, or shape is not a + TensorShape, list, or tuple. + """ + if not isinstance(tensor_proto, tensor_pb2.TensorProto): + raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object") + if isinstance(shape, tensor_shape_pb2.TensorShapeProto): + shape = [d.size for d in shape.dim] + elif not isinstance(shape, (list, tuple)): + raise TypeError("shape is not a list or tuple") + tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim] + return all(x == y for x, y in zip(tensor_shape_list, shape)) + + +def ConstantValue(tensor): + """Returns the constant value of the given tensor, if efficiently calculable. + + This function attempts to partially evaluate the given tensor, and + returns its value as a numpy ndarray if this succeeds. + + TODO(mrry): Consider whether this function should use a registration + mechanism like gradients and ShapeFunctions, so that it is easily + extensible. + + Args: + tensor: The Tensor to be evaluated. + + Returns: + A numpy ndarray containing the constant value of the given `tensor`, + or None if it cannot be calculated. + + Raises: + TypeError: if tensor is not an ops.Tensor. + """ + # TODO(mdevin): Support Variables? 
+ if not isinstance(tensor, ops.Tensor): + raise TypeError("tensor is not a Tensor") + if tensor.op.type == "Const": + return MakeNdarray(tensor.op.get_attr("value")) + elif tensor.op.type == "Shape": + input_shape = tensor.op.inputs[0].get_shape() + if input_shape.is_fully_defined(): + return np.array([dim.value for dim in input_shape.dims]) + else: + return None + elif tensor.op.type == "Size": + input_shape = tensor.op.inputs[0].get_shape() + if input_shape.is_fully_defined(): + return np.array([np.prod([dim.value for dim in input_shape.dims])]) + else: + return None + elif tensor.op.type == "Rank": + input_shape = tensor.op.inputs[0].get_shape() + if input_shape.ndims is not None: + return np.array([input_shape.ndims]) + else: + return None + elif tensor.op.type == "Range": + start = ConstantValue(tensor.op.inputs[0]) + if start is None: + return None + limit = ConstantValue(tensor.op.inputs[1]) + if limit is None: + return None + delta = ConstantValue(tensor.op.inputs[2]) + if delta is None: + return None + return np.array(range(start, limit, delta), + dtype=tensor.dtype.as_numpy_dtype) + else: + return None diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py new file mode 100644 index 0000000000..7c1c0b8d3e --- /dev/null +++ b/tensorflow/python/framework/tensor_util_test.py @@ -0,0 +1,379 @@ +"""Functional tests for tensor_util.""" +import tensorflow.python.platform + +import numpy as np + +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import googletest + + +class TensorUtilTest(test_util.TensorFlowTestCase): + + def testFloat(self): + t = tensor_util.make_tensor_proto(10.0) + self.assertProtoEquals(""" + dtype: DT_FLOAT + 
      tensor_shape {}
      float_val: 10.0
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array(10.0, dtype=np.float32), a)

  # Multi-element float tensors use the packed tensor_content encoding;
  # MakeNdarray should round-trip the proto back to the original array.
  def testFloatN(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTyped(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerce(self):
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerceNdarray(self):
    arr = np.asarray([10, 20, 30], dtype="int")
    t = tensor_util.make_tensor_proto(arr, dtype=types.float32)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatSizes(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)

  def testFloatSizes2(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 3 } dim { size: 1 } }
      tensor_content: "\000\000 A\000\000\240A\000\000\360A"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32),
                        a)

  def testFloatSizesLessValues(self):
    t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      float_val: 10.0
      """, t)
    # No conversion to Ndarray for this one: not enough values.

  def testFloatNpArrayFloat64(self):
    t = tensor_util.make_tensor_proto(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
    self.assertProtoEquals("""
      dtype: DT_DOUBLE
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.float64, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
                        tensor_util.MakeNdarray(t))

  def testFloatTypesWithImplicitRepeat(self):
    for dtype, nptype in [
        (types.float32, np.float32), (types.float64, np.float64)]:
      t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0],
                                    [10.0, 10.0, 10.0, 10.0],
                                    [10.0, 10.0, 10.0, 10.0]], dtype=nptype), a)

  def testInt(self):
    t = tensor_util.make_tensor_proto(10)
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape {}
      int_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int32, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int32), a)

  def testIntNDefaultType(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int32, a.dtype)
    self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)

  def testIntTypes(self):
    for dtype, nptype in [
        (types.int32, np.int32),
        (types.uint8, np.uint8),
        (types.int16, np.int16),
        (types.int8, np.int8)]:
      # Test with array.
      t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
      self.assertEquals(dtype, t.dtype)
      self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
      a = tensor_util.MakeNdarray(t)
      self.assertEquals(nptype, a.dtype)
      self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
      # Test with ndarray.
      t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
      self.assertEquals(dtype, t.dtype)
      self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
      a = tensor_util.MakeNdarray(t)
      self.assertEquals(nptype, a.dtype)
      self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  def testIntTypesWithImplicitRepeat(self):
    for dtype, nptype in [
        (types.int64, np.int64),
        (types.int32, np.int32),
        (types.uint8, np.uint8),
        (types.int16, np.int16),
        (types.int8, np.int8)]:
      t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllEqual(np.array([[10, 10, 10, 10],
                                    [10, 10, 10, 10],
                                    [10, 10, 10, 10]], dtype=nptype), a)

  def testLong(self):
    t = tensor_util.make_tensor_proto(10, dtype=types.int64)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int64), a)

  def testLongN(self):
    t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3],
                                      dtype=types.int64)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array([[10, 20, 30]], dtype=np.int64), a)

  def testLongNpArray(self):
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape { dim { size: 3 } }
      tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.int64, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)

  def testString(self):
    t = tensor_util.make_tensor_proto("foo")
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape {}
      string_val: "foo"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertEquals(["foo"], a)

  def testStringWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto("f", shape=[3, 4])
    a = tensor_util.MakeNdarray(t)
    self.assertAllEqual(np.array([["f", "f", "f", "f"],
                                  ["f", "f", "f", "f"],
                                  ["f", "f", "f", "f"]], dtype=np.object), a)

  def testStringN(self):
    t = tensor_util.make_tensor_proto(["foo", "bar", "baz"], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertAllEqual(np.array([["foo", "bar", "baz"]]), a)

  def testStringNpArray(self):
    t = tensor_util.make_tensor_proto(np.array([["a", "ab"], ["abc", "abcd"]]))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.object, a.dtype)
    self.assertAllEqual(np.array([["a", "ab"], ["abc", "abcd"]]), a)

  def testComplex(self):
    t = tensor_util.make_tensor_proto((1+2j), dtype=types.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape {}
      scomplex_val: 1
      scomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplexWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4],
                                      dtype=types.complex64)
    a = tensor_util.MakeNdarray(t)
    self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)],
                                  [(1+1j), (1+1j), (1+1j), (1+1j)],
                                  [(1+1j), (1+1j), (1+1j), (1+1j)]],
                                 dtype=np.complex64), a)

  def testComplexN(self):
    t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3],
                                      dtype=types.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1+2j), (3+4j), (5+6j)]]), a)

  def testComplexNpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=types.complex64)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      scomplex_val: 7
      scomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEquals(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), a)

  def testUnsupportedDType(self):
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(np.array([1]), 0)

  def testShapeTooLarge(self):
    with self.assertRaises(ValueError):
      tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])

  def testLowRankSupported(self):
    t = tensor_util.make_tensor_proto(np.array(7))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 7
      """, t)

  def testShapeEquals(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
    self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
    self.assertTrue(
        tensor_util.ShapeEquals(t, tensor_util.MakeTensorShapeProto([2, 2])))
    self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
    self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
    self.assertFalse(tensor_util.ShapeEquals(t, [4]))


class ConstantValueTest(test_util.TensorFlowTestCase):

  def testConstant(self):
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))

    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val))

  def testUnknown(self):
    tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=types.float32)
    self.assertIs(None, tensor_util.ConstantValue(tf_val))

  def testShape(self):
    np_val = np.array([1, 2, 3])
    tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))

  def testSize(self):
    np_val = np.array([6])
    tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))

  def testRank(self):
    np_val = np.array([3])
    tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
    self.assertAllEqual(np_val, tensor_util.ConstantValue(tf_val))


if __name__ == "__main__":
  googletest.main()
diff --git a/tensorflow/python/framework/test_kernel_label_op.cc b/tensorflow/python/framework/test_kernel_label_op.cc
new file mode 100644
index 0000000000..50f8522e1b
--- /dev/null
+++ b/tensorflow/python/framework/test_kernel_label_op.cc
@@ -0,0 +1,47 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

REGISTER_OP("KernelLabel").Output("result: string");

namespace {
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
}  // namespace

// Test kernel: emits a scalar string identifying which labeled
// registration was selected (see the REGISTER_KERNEL_BUILDER calls below).
template <KernelLabel KL>
class KernelLabelOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* ctx) override {
    Tensor* output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("result", TensorShape({}),
&output)); + switch (KL) { + case DEFAULT_LABEL: + output->scalar<string>()() = "My label is: default"; + break; + case OVERLOAD_1_LABEL: + output->scalar<string>()() = "My label is: overload_1"; + break; + case OVERLOAD_2_LABEL: + output->scalar<string>()() = "My label is: overload_2"; + break; + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU), + KernelLabelOp<DEFAULT_LABEL>); +REGISTER_KERNEL_BUILDER(Name("KernelLabel") + .Device(DEVICE_CPU) + .Label("overload_1"), + KernelLabelOp<OVERLOAD_1_LABEL>); +REGISTER_KERNEL_BUILDER(Name("KernelLabel") + .Device(DEVICE_CPU) + .Label("overload_2"), + KernelLabelOp<OVERLOAD_2_LABEL>); + +} // end namespace tensorflow diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py new file mode 100644 index 0000000000..597a5ad829 --- /dev/null +++ b/tensorflow/python/framework/test_util.py @@ -0,0 +1,437 @@ +# pylint: disable=invalid-name +"""Test utils for tensorflow.""" +import contextlib +import math +import re +import threading + +import tensorflow.python.platform + +import numpy as np + +from google.protobuf import text_format + +from tensorflow.core.framework import config_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import graph_util +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import googletest +from tensorflow.python.platform import logging +from tensorflow.python.util.protobuf import compare + + +def IsGoogleCudaEnabled(): + return pywrap_tensorflow.IsGoogleCudaEnabled() + + +class TensorFlowTestCase(googletest.TestCase): + """Root class for tests that need to test tensor flow. 
+ """ + + def __init__(self, methodName="runTest"): + super(TensorFlowTestCase, self).__init__(methodName) + self._threads = [] + self._tempdir = None + self._cached_session = None + + def setUp(self): + self._ClearCachedSession() + ops.reset_default_graph() + + def tearDown(self): + for thread in self._threads: + self.assertFalse(thread.is_alive(), "A checkedThread did not terminate") + self._ClearCachedSession() + + def _ClearCachedSession(self): + if self._cached_session is not None: + self._cached_session.close() + self._cached_session = None + + def get_temp_dir(self): + if not self._tempdir: + self._tempdir = googletest.GetTempDir() + return self._tempdir + + def _AssertProtoEquals(self, a, b): + """Asserts that a and b are the same proto. + + Uses Proto2Cmp() first, as it returns correct results + for floating point attributes, and then use assertProto2Equal() + in case of failure as it provides good error messages. + + Args: + a: a proto. + b: another proto. + """ + if compare.Proto2Cmp(a, b) != 0: + compare.assertProto2Equal(self, a, b, normalize_numbers=True) + + def assertProtoEquals(self, expected_message_maybe_ascii, message): + """Asserts that message is same as parsed expected_message_ascii. + + Creates another prototype of message, reads the ascii message into it and + then compares them using self._AssertProtoEqual(). 
+ + Args: + expected_message_maybe_ascii: proto message in original or ascii form + message: the message to validate + """ + + if type(expected_message_maybe_ascii) == type(message): + expected_message = expected_message_maybe_ascii + self._AssertProtoEquals(expected_message, message) + elif isinstance(expected_message_maybe_ascii, str): + expected_message = type(message)() + text_format.Merge(expected_message_maybe_ascii, expected_message) + self._AssertProtoEquals(expected_message, message) + else: + assert False, ("Can't compare protos of type " + + type(expected_message_maybe_ascii) + " and " + + type(message)) + + def assertStartsWith(self, actual, expected_start, msg=None): + """Assert that actual.startswith(expected_start) is True. + + Args: + actual: str + expected_start: str + msg: Optional message to report on failure. + """ + if not actual.startswith(expected_start): + fail_msg = "%r does not start with %r" % (actual, expected_start) + fail_msg += " : %r" % (msg) if msg else "" + self.fail(fail_msg) + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager + def test_session(self, + graph=None, + config=None, + use_gpu=False, + force_gpu=False): + """Returns a TensorFlow Session for use in executing tests. + + This method should be used for all functional tests. + + Use the `use_gpu` and `force_gpu` options to control where ops are run. If + `force_gpu` is True, all ops are pinned to `/gpu:0`. Otherwise, if `use_gpu` + is True, TensorFlow tries to run as many ops on the GPU as possible. If both + `force_gpu and `use_gpu` are False, all ops are pinned to the CPU. 
+ + Example: + + class MyOperatorTest(test_util.TensorFlowTestCase): + def testMyOperator(self): + with self.test_session(use_gpu=True): + valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] + result = MyOperator(valid_input).eval() + self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] + invalid_input = [-1.0, 2.0, 7.0] + with self.assertRaisesOpError("negative input not supported"): + MyOperator(invalid_input).eval() + + Args: + graph: Optional graph to use during the returned session. + config: An optional config_pb2.ConfigProto to use to configure the + session. + use_gpu: If True, attempt to run as many ops as possible on GPU. + force_gpu: If True, pin all ops to `/gpu:0`. + + Returns: + A Session object that should be used as a context manager to surround + the graph building and execution code in a test case. + """ + def prepare_config(config): + if config is None: + config = config_pb2.ConfigProto() + config.allow_soft_placement = not force_gpu + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + elif force_gpu and config.allow_soft_placement: + config = config_pb2.ConfigProto().CopyFrom(config) + config.allow_soft_placement = False + return config + + if graph is None: + if self._cached_session is None: + self._cached_session = session.Session(graph=None, + config=prepare_config(config)) + sess = self._cached_session + with sess.graph.as_default(), sess.as_default(): + if force_gpu: + with sess.graph.device("/gpu:0"): + yield sess + elif use_gpu: + yield sess + else: + with sess.graph.device(graph_util.pin_to_cpu): + yield sess + else: + with session.Session(graph=graph, config=prepare_config(config)) as sess: + if force_gpu: + with sess.graph.device("/gpu:0"): + yield sess + elif use_gpu: + yield sess + else: + with sess.graph.device(graph_util.pin_to_cpu): + yield sess + # pylint: enable=g-doc-return-or-yield + + class _CheckedThread(object): + """A wrapper class for Thread that asserts successful completion. 
+ + This class should be created using the TensorFlowTestCase.checkedThread() + method. + """ + + def __init__(self, testcase, target, args=None, kwargs=None): + """Constructs a new instance of _CheckedThread. + + Args: + testcase: The TensorFlowTestCase for which this thread is being created. + target: A callable object representing the code to be executed in the + thread. + args: A tuple of positional arguments that will be passed to target. + kwargs: A dictionary of keyword arguments that will be passed to target. + """ + self._testcase = testcase + self._target = target + self._args = () if args is None else args + self._kwargs = {} if kwargs is None else kwargs + self._thread = threading.Thread(target=self._protected_run) + self._exception = None + + def _protected_run(self): + """Target for the wrapper thread. Sets self._exception on failure.""" + try: + self._target(*self._args, **self._kwargs) +# pylint: disable=broad-except + except Exception as e: + # pylint: enable=broad-except + self._exception = e + + def start(self): + """Starts the thread's activity. + + This must be called at most once per _CheckedThread object. It arranges + for the object's target to be invoked in a separate thread of control. + """ + self._thread.start() + + def join(self): + """Blocks until the thread terminates. + + Raises: + self._testcase.failureException: If the thread terminates with due to + an exception. + """ + self._thread.join() + if self._exception is not None: + self._testcase.fail( + "Error in checkedThread: %s" % str(self._exception)) + + def is_alive(self): + """Returns whether the thread is alive. + + This method returns True just before the run() method starts + until just after the run() method terminates. + + Returns: + True if the thread is alive, otherwise False. + """ + return self._thread.is_alive() + + def checkedThread(self, target, args=None, kwargs=None): + """Returns a Thread wrapper that asserts 'target' completes successfully. 
+ + This method should be used to create all threads in test cases, as + otherwise there is a risk that a thread will silently fail, and/or + assertions made in the thread will not be respected. + + Args: + target: A callable object to be executed in the thread. + args: The argument tuple for the target invocation. Defaults to (). + kwargs: A dictionary of keyword arguments for the target invocation. + Defaults to {}. + + Returns: + A wrapper for threading.Thread that supports start() and join() methods. + """ + ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs) + self._threads.append(ret) + return ret +# pylint: enable=invalid-name + + def assertNear(self, f1, f2, err): + """Asserts that two floats are near each other. + + Checks that |f1 - f2| < err and asserts a test failure + if not. + + Args: + f1: a float value. + f2: a float value. + err: a float value. + """ + self.assertTrue(math.fabs(f1 - f2) < err) + + def assertArrayNear(self, farray1, farray2, err): + """Asserts that two float arrays are near each other. + + Checks that for all elements of farray1 and farray2 + |f1 - f2| < err. Asserts a test failure if not. + + Args: + farray1: a list of float values. + farray2: a list of float values. + err: a float value. + """ + for f1, f2 in zip(farray1, farray2): + self.assertNear(f1, f2, err) + + def _NDArrayNear(self, ndarray1, ndarray2, err): + return np.linalg.norm(ndarray1 - ndarray2) < err + + def assertNDArrayNear(self, ndarray1, ndarray2, err): + """Asserts that two numpy arrays have near values. + + Args: + ndarray1: a numpy ndarray. + ndarray2: a numpy ndarray. + err: a float. The maximum absolute difference allowed. + """ + self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err)) + + def _GetNdArray(self, a): + if not isinstance(a, np.ndarray): + a = np.array(a) + return a + + def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): + """Asserts that two numpy arrays have near values. 
+ + Args: + a: a numpy ndarray or anything can be converted to one. + b: a numpy ndarray or anything can be converted to one. + rtol: relative tolerance + atol: absolute tolerance + """ + a = self._GetNdArray(a) + b = self._GetNdArray(b) + self.assertEqual( + a.shape, b.shape, + "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + if not np.allclose(a, b, rtol=rtol, atol=atol): + # Prints more details than np.testing.assert_allclose. + # + # NOTE: numpy.allclose (and numpy.testing.assert_allclose) + # checks whether two arrays are element-wise equal within a + # tolerance. The relative difference (rtol * abs(b)) and the + # absolute difference atol are added together to compare against + # the absolute difference between a and b. Here, we want to + # print out which elements violate such conditions. + cond = np.abs(a - b) > atol + rtol * np.abs(b) + if a.ndim: + x = a[np.where(cond)] + y = b[np.where(cond)] + print "not close where = ", np.where(cond) + else: + # np.where is broken for scalars + x, y = a, b + print "not close lhs = ", x + print "not close rhs = ", y + print "not close dif = ", np.abs(x - y) + print "not close tol = ", atol + rtol * np.abs(y) + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + + def assertAllEqual(self, a, b): + """Asserts that two numpy arrays have the same values. + + Args: + a: a numpy ndarray or anything can be converted to one. + b: a numpy ndarray or anything can be converted to one. + """ + a = self._GetNdArray(a) + b = self._GetNdArray(b) + self.assertEqual( + a.shape, b.shape, + "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + same = (a == b) + + if a.dtype == np.float32 or a.dtype == np.float64: + same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) + if not np.all(same): + # Prints more details than np.testing.assert_array_equal. 
+ diff = np.logical_not(same) + if a.ndim: + x = a[np.where(diff)] + y = b[np.where(diff)] + print "not equal where = ", np.where(diff) + else: + # np.where is broken for scalars + x, y = a, b + print "not equal lhs = ", x + print "not equal rhs = ", y + np.testing.assert_array_equal(a, b) + + # pylint: disable=g-doc-return-or-yield + @contextlib.contextmanager + def assertRaisesWithPredicateMatch(self, exception_type, + expected_err_re_or_predicate): + """Returns a context manager to enclose code expected to raise an exception. + + Args: + exception_type: The expected type of exception that should be raised. + expected_err_re_or_predicate: If this is callable, it should be a function + of one argument that inspects the passed-in OpError exception and + returns True (success) or False (please fail the test). Otherwise, the + error message is expected to match this regular expression partially. + + Returns: + A context manager to surround code that is expected to raise an + errors.OpError exception. + """ + if callable(expected_err_re_or_predicate): + predicate = expected_err_re_or_predicate + else: + def predicate(e): + err_str = e.message + op = e.op + while op is not None: + err_str += "\nCaused by: " + op.name + op = op._original_op + logging.info("Searching within error strings: '%s' within '%s'", + expected_err_re_or_predicate, err_str) + return re.search(expected_err_re_or_predicate, err_str) + try: + yield + self.fail(exception_type.__name__ + " not raised") +# pylint: disable=broad-except + except Exception as e: + # pylint: enable=broad-except + if not isinstance(e, exception_type) or not predicate(e): + raise AssertionError(e) + # pylint: enable=g-doc-return-or-yield + + def assertRaisesOpError(self, expected_err_re_or_predicate): + return self.assertRaisesWithPredicateMatch(errors.OpError, + expected_err_re_or_predicate) + + def assertShapeEqual(self, np_array, tf_tensor): + """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 
+ + Args: + np_array: A Numpy ndarray or Numpy scalar. + tf_tensor: A Tensor. + + Raises: + TypeError: If the arguments have the wrong type. + """ + if not isinstance(np_array, (np.ndarray, np.generic)): + raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") + if not isinstance(tf_tensor, ops.Tensor): + raise TypeError("tf_tensor must be a Tensor") + self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list()) diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py new file mode 100644 index 0000000000..e0618cfea4 --- /dev/null +++ b/tensorflow/python/framework/test_util_test.py @@ -0,0 +1,128 @@ +"""Tests for tensorflow.ops.test_util.""" +import threading + +import tensorflow.python.platform +import numpy as np + +from google.protobuf import text_format + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.platform import googletest +from tensorflow.python.ops import logging_ops + +class TestUtilTest(test_util.TensorFlowTestCase): + + def testIsGoogleCudaEnabled(self): + # The test doesn't assert anything. It ensures the py wrapper + # function is generated correctly. 
+ if test_util.IsGoogleCudaEnabled(): + print "GoogleCuda is enabled" + else: + print "GoogleCuda is disabled" + + def testAssertProtoEqualsStr(self): + + graph_str = "node { name: 'w1' op: 'params' }" + graph_def = graph_pb2.GraphDef() + text_format.Merge(graph_str, graph_def) + + # test string based comparison + self.assertProtoEquals(graph_str, graph_def) + + # test original comparison + self.assertProtoEquals(graph_def, graph_def) + + def testNDArrayNear(self): + a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]) + self.assertTrue(self._NDArrayNear(a1, a2, 1e-5)) + self.assertFalse(self._NDArrayNear(a1, a3, 1e-5)) + + def testCheckedThreadSucceeds(self): + def noop(ev): + ev.set() + + event_arg = threading.Event() + + self.assertFalse(event_arg.is_set()) + t = self.checkedThread(target=noop, args=(event_arg,)) + t.start() + t.join() + self.assertTrue(event_arg.is_set()) + + def testCheckedThreadFails(self): + def err_func(): + return 1 / 0 + + t = self.checkedThread(target=err_func) + t.start() + with self.assertRaises(self.failureException) as fe: + t.join() + self.assertTrue("integer division or modulo by zero" + in fe.exception.message) + + def testCheckedThreadWithWrongAssertionFails(self): + x = 37 + + def err_func(): + self.assertTrue(x < 10) + + t = self.checkedThread(target=err_func) + t.start() + with self.assertRaises(self.failureException) as fe: + t.join() + self.assertTrue("False is not true" in fe.exception.message) + + def testMultipleThreadsWithOneFailure(self): + def err_func(i): + self.assertTrue(i != 7) + + threads = [self.checkedThread(target=err_func, args=(i,)) + for i in range(10)] + for t in threads: + t.start() + for i, t in enumerate(threads): + if i == 7: + with self.assertRaises(self.failureException): + t.join() + else: + t.join() + + def _WeMustGoDeeper(self, msg): + with self.assertRaisesOpError(msg): + node_def = 
ops._NodeDef("op_type", "name") + node_def_orig = ops._NodeDef("op_type_orig", "orig") + op_orig = ops.Operation(node_def_orig, ops.get_default_graph()) + op = ops.Operation(node_def, ops.get_default_graph(), original_op=op_orig) + raise errors.UnauthenticatedError(node_def, op, "true_err") + + def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self): + with self.assertRaises(AssertionError): + self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for") + + self._WeMustGoDeeper("true_err") + self._WeMustGoDeeper("name") + self._WeMustGoDeeper("orig") + + def testAllCloseScalars(self): + self.assertAllClose(7, 7 + 1e-8) + with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"): + self.assertAllClose(7, 8) + + def testForceGPU(self): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Cannot assign a device to node"): + with self.test_session(force_gpu=True): + # this relies on us not having a GPU implementation for assert, which + # seems sensible + x = [True] + y = [15] + logging_ops.Assert(x, y).run() + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/types.py b/tensorflow/python/framework/types.py new file mode 100644 index 0000000000..6a8c629fe4 --- /dev/null +++ b/tensorflow/python/framework/types.py @@ -0,0 +1,418 @@ +"""Library of dtypes (Tensor element types).""" +import tensorflow.python.platform + +import numpy as np + +from tensorflow.core.framework import types_pb2 + + +class DType(object): + """Represents the type of the elements in a `Tensor`. + + The following `DType` objects are defined: + + * `tf.float32`: 32-bit single-precision floating-point. + * `tf.float64`: 64-bit double-precision floating-point. + * `tf.bfloat16`: 16-bit truncated floating-point. + * `tf.complex64`: 64-bit single-precision complex. + + * `tf.int8`: 8-bit signed integer. + * `tf.uint8`: 8-bit unsigned integer. + * `tf.int32`: 32-bit signed integer. + * `tf.int64`: 64-bit signed integer. 
+ + * `tf.bool`: Boolean. + + * `tf.string`: String. + + * `tf.qint8`: Quantized 8-bit signed integer. + * `tf.quint8`: Quantized 8-bit unsigned integer. + * `tf.qint32`: Quantized 32-bit signed integer. + + In addition, variants of these types with the `_ref` suffix are + defined for reference-typed tensors. + + The `tf.as_dtype()` function converts numpy types and string type + names to a `DType` object. + + @@is_compatible_with + @@name + @@base_dtype + @@is_ref_dtype + @@as_ref + @@is_integer + @@is_quantized + + @@as_numpy_dtype + @@as_datatype_enum + """ + + def __init__(self, type_enum): + """Creates a new `DataType`. + + NOTE(mrry): In normal circumstances, you should not need to + construct a DataType object directly. Instead, use the + types.as_dtype() function. + + Args: + type_enum: A `types_pb2.DataType` enum value. + + Raises: + TypeError: If `type_enum` is not a value `types_pb2.DataType`. + + """ + # TODO(mrry): Make the necessary changes (using __new__) to ensure + # that calling this returns one of the interned values. 
+ type_enum = int(type_enum) + if (type_enum not in types_pb2.DataType.values() + or type_enum == types_pb2.DT_INVALID): + raise TypeError( + "type_enum is not a valid types_pb2.DataType: %s" % type_enum) + self._type_enum = type_enum + + @property + def is_ref_dtype(self): + """Returns `True` if this `DType` represents a reference type.""" + return self._type_enum > 100 + + @property + def as_ref(self): + """Returns a reference `DType` based on this `DType`.""" + if self.is_ref_dtype: + return self + else: + return _INTERN_TABLE[self._type_enum + 100] + + @property + def base_dtype(self): + """Returns a non-reference `DType` based on this `DType`.""" + if self.is_ref_dtype: + return _INTERN_TABLE[self._type_enum - 100] + else: + return self + + @property + def as_numpy_dtype(self): + """Returns a `numpy.dtype` based on this `DType`.""" + return _TF_TO_NP[self._type_enum] + + @property + def as_datatype_enum(self): + """Returns a `types_pb2.DataType` enum value based on this `DType`.""" + return self._type_enum + + @property + def is_integer(self): + """Returns whether this is a (non-quantized) integer type.""" + return (not self.is_quantized and + issubclass(self.as_numpy_dtype, np.integer)) + + @property + def is_quantized(self): + """Returns whether this is a quantized data type.""" + return self.base_dtype in [qint8, quint8, qint32, bfloat16] + + @property + def min(self): + """Returns the minimum representable value in this data type. + + Raises: + TypeError: if this is a non-numeric, unordered, or quantized type. + + """ + if (self.is_quantized or self.base_dtype == bool or + self.base_dtype == string or self.base_dtype == complex64): + raise TypeError("Cannot find minimum value of %s." 
% self) + + # there is no simple way to get the min value of a dtype, we have to check + # float and int types separately + try: + return np.finfo(self.as_numpy_dtype()).min + except: # bare except as possible raises by finfo not documented + try: + return np.iinfo(self.as_numpy_dtype()).min + except: + raise TypeError("Cannot find minimum value of %s." % self) + + @property + def max(self): + """Returns the maximum representable value in this data type. + + Raises: + TypeError: if this is a non-numeric, unordered, or quantized type. + + """ + if (self.is_quantized or self.base_dtype == bool or + self.base_dtype == string or self.base_dtype == complex64): + raise TypeError("Cannot find maximum value of %s." % self) + + # there is no simple way to get the min value of a dtype, we have to check + # float and int types separately + try: + return np.finfo(self.as_numpy_dtype()).max + except: # bare except as possible raises by finfo not documented + try: + return np.iinfo(self.as_numpy_dtype()).max + except: + raise TypeError("Cannot find maximum value of %s." % self) + + def is_compatible_with(self, other): + """Returns True if the `other` DType will be converted to this DType. + + The conversion rules are as follows: + + ``` + DType(T) .is_compatible_with(DType(T)) == True + DType(T) .is_compatible_with(DType(T).as_ref) == True + DType(T).as_ref.is_compatible_with(DType(T)) == False + DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True + ``` + + Args: + other: A `DType` (or object that may be converted to a `DType`). + + Returns: + True if a Tensor of the `other` `DType` will be implicitly converted to + this `DType`. 
+ """ + other = as_dtype(other) + return self._type_enum in ( + other.as_datatype_enum, other.base_dtype.as_datatype_enum) + + def __eq__(self, other): + """Returns True iff this DType refers to the same type as `other`.""" + return (other is not None + and self._type_enum == as_dtype(other).as_datatype_enum) + + def __ne__(self, other): + """Returns True iff self != other.""" + return not self.__eq__(other) + + @property + def name(self): + """Returns the string name for this `DType`.""" + return _TYPE_TO_STRING[self._type_enum] + + def __str__(self): + return "<dtype: %r>" % self.name + + def __repr__(self): + return "tf." + self.name + + +# Define standard wrappers for the types_pb2.DataType enum. +float32 = DType(types_pb2.DT_FLOAT) +float64 = DType(types_pb2.DT_DOUBLE) +double = float64 +int32 = DType(types_pb2.DT_INT32) +uint8 = DType(types_pb2.DT_UINT8) +int16 = DType(types_pb2.DT_INT16) +int8 = DType(types_pb2.DT_INT8) +string = DType(types_pb2.DT_STRING) +complex64 = DType(types_pb2.DT_COMPLEX64) +int64 = DType(types_pb2.DT_INT64) +bool = DType(types_pb2.DT_BOOL) +qint8 = DType(types_pb2.DT_QINT8) +quint8 = DType(types_pb2.DT_QUINT8) +qint32 = DType(types_pb2.DT_QINT32) +bfloat16 = DType(types_pb2.DT_BFLOAT16) +float32_ref = DType(types_pb2.DT_FLOAT_REF) +float64_ref = DType(types_pb2.DT_DOUBLE_REF) +double_ref = float64_ref +int32_ref = DType(types_pb2.DT_INT32_REF) +uint8_ref = DType(types_pb2.DT_UINT8_REF) +int16_ref = DType(types_pb2.DT_INT16_REF) +int8_ref = DType(types_pb2.DT_INT8_REF) +string_ref = DType(types_pb2.DT_STRING_REF) +complex64_ref = DType(types_pb2.DT_COMPLEX64_REF) +int64_ref = DType(types_pb2.DT_INT64_REF) +bool_ref = DType(types_pb2.DT_BOOL_REF) +qint8_ref = DType(types_pb2.DT_QINT8_REF) +quint8_ref = DType(types_pb2.DT_QUINT8_REF) +qint32_ref = DType(types_pb2.DT_QINT32_REF) +bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF) + + +# Maintain an intern table so that we don't have to create a large +# number of small objects. 
# One canonical DType instance per DataType enum value.  as_dtype() and the
# DType properties (as_ref, base_dtype) hand out these shared singletons
# instead of allocating a fresh DType per lookup.  Built from the module-level
# singletons above so each entry's key is, by construction, that object's own
# as_datatype_enum.
_INTERN_TABLE = dict(
    (dt.as_datatype_enum, dt)
    for dt in [
        float32, float64, int32, uint8, int16, int8, string, complex64,
        int64, bool, qint8, quint8, qint32, bfloat16,
        float32_ref, float64_ref, int32_ref, uint8_ref, int16_ref,
        int8_ref, string_ref, complex64_ref, int64_ref, bool_ref,
        qint8_ref, quint8_ref, qint32_ref, bfloat16_ref,
    ])


# Standard mappings between types_pb2.DataType values and string names.
# Canonical string name for every DataType enum value.  Each reference dtype
# reuses the base name with a "_ref" suffix, and its enum value sits exactly
# 100 above the corresponding base enum — the same offset that DType.as_ref
# and DType.base_dtype rely on.
_TYPE_TO_STRING = {}
for _enum, _name in [
    (types_pb2.DT_FLOAT, "float32"),
    (types_pb2.DT_DOUBLE, "float64"),
    (types_pb2.DT_INT32, "int32"),
    (types_pb2.DT_UINT8, "uint8"),
    (types_pb2.DT_INT16, "int16"),
    (types_pb2.DT_INT8, "int8"),
    (types_pb2.DT_STRING, "string"),
    (types_pb2.DT_COMPLEX64, "complex64"),
    (types_pb2.DT_INT64, "int64"),
    (types_pb2.DT_BOOL, "bool"),
    (types_pb2.DT_QINT8, "qint8"),
    (types_pb2.DT_QUINT8, "quint8"),
    (types_pb2.DT_QINT32, "qint32"),
    (types_pb2.DT_BFLOAT16, "bfloat16"),
]:
  _TYPE_TO_STRING[_enum] = _name
  _TYPE_TO_STRING[_enum + 100] = _name + "_ref"
del _enum, _name

# Reverse lookup: string name -> interned DType singleton.
_STRING_TO_TF = dict(
    (name, _INTERN_TABLE[enum])
    for enum, name in _TYPE_TO_STRING.iteritems())
# Add non-canonical aliases.
_STRING_TO_TF.update({
    "float": float32,
    "float_ref": float32_ref,
    "double": float64,
    "double_ref": float64_ref,
})


# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
_np_qint32 = np.dtype([("qint32", np.int32, 1)])

# Standard mappings between types_pb2.DataType values and numpy.dtypes.
+_NP_TO_TF = frozenset([ + (np.float32, float32), + (np.float64, float64), + (np.int32, int32), + (np.int64, int64), + (np.uint8, uint8), + (np.int16, int16), + (np.int8, int8), + (np.complex64, complex64), + (np.object, string), + (np.bool, bool), + (_np_qint8, qint8), + (_np_quint8, quint8), + (_np_qint32, qint32), + # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16. +]) +_TF_TO_NP = { + types_pb2.DT_FLOAT: np.float32, + types_pb2.DT_DOUBLE: np.float64, + types_pb2.DT_INT32: np.int32, + types_pb2.DT_UINT8: np.uint8, + types_pb2.DT_INT16: np.int16, + types_pb2.DT_INT8: np.int8, + # NOTE(mdevin): For strings we use np.object as it supports variable length + # strings. + types_pb2.DT_STRING: np.object, + types_pb2.DT_COMPLEX64: np.complex64, + types_pb2.DT_INT64: np.int64, + types_pb2.DT_BOOL: np.bool, + types_pb2.DT_QINT8: _np_qint8, + types_pb2.DT_QUINT8: _np_quint8, + types_pb2.DT_QINT32: _np_qint32, + types_pb2.DT_BFLOAT16: np.uint16, + + # Ref types + types_pb2.DT_FLOAT_REF: np.float32, + types_pb2.DT_DOUBLE_REF: np.float64, + types_pb2.DT_INT32_REF: np.int32, + types_pb2.DT_UINT8_REF: np.uint8, + types_pb2.DT_INT16_REF: np.int16, + types_pb2.DT_INT8_REF: np.int8, + types_pb2.DT_STRING_REF: np.object, + types_pb2.DT_COMPLEX64_REF: np.complex64, + types_pb2.DT_INT64_REF: np.int64, + types_pb2.DT_BOOL_REF: np.bool, + types_pb2.DT_QINT8_REF: _np_qint8, + types_pb2.DT_QUINT8_REF: _np_quint8, + types_pb2.DT_QINT32_REF: _np_qint32, + types_pb2.DT_BFLOAT16_REF: np.uint16, +} + + +QUANTIZED_DTYPES = frozenset( + [qint8, quint8, qint32, qint8_ref, quint8_ref, qint32_ref]) + + +def as_dtype(type_value): + """Converts the given `type_value` to a `DType`. + + Args: + type_value: A value that can be converted to a `tf.DType` + object. This may currently be a `tf.DType` object, a + [`DataType` enum](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/types.proto), + a string type name, or a `numpy.dtype`. 
+ + Returns: + A `DType` corresponding to `type_value`. + + Raises: + TypeError: If `type_value` cannot be converted to a `DType`. + """ + if isinstance(type_value, DType): + return type_value + + try: + return _INTERN_TABLE[type_value] + except KeyError: + pass + + try: + return _STRING_TO_TF[type_value] + except KeyError: + pass + + if isinstance(type_value, np.dtype): + # The numpy dtype for strings is variable length. We can not compare + # dtype with a single constant (np.string does not exist) to decide + # dtype is a "string" type. We need to compare the dtype.type to be + # sure it's a string type. + if type_value.type == np.string_ or type_value.type == np.unicode_: + return string + + for key, val in _NP_TO_TF: + if key == type_value: + return val + + raise TypeError( + "Cannot convert value %r to a TensorFlow DType." % type_value) diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py new file mode 100644 index 0000000000..acd2994339 --- /dev/null +++ b/tensorflow/python/framework/types_test.py @@ -0,0 +1,174 @@ +"""Tests for tensorflow.python.framework.importer.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.platform import googletest + + +class TypesTest(test_util.TensorFlowTestCase): + + def testAllTypesConstructible(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + self.assertEqual( + datatype_enum, types.DType(datatype_enum).as_datatype_enum) + + def testAllTypesConvertibleToDType(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + self.assertEqual( + datatype_enum, types.as_dtype(datatype_enum).as_datatype_enum) + + def testAllTypesConvertibleToNumpyDtype(self): + for 
datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = types.as_dtype(datatype_enum) + numpy_dtype = dtype.as_numpy_dtype + _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype) + if dtype.base_dtype != types.bfloat16: + # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16. + self.assertEqual( + types.as_dtype(datatype_enum).base_dtype, types.as_dtype(numpy_dtype)) + + def testInvalid(self): + with self.assertRaises(TypeError): + types.DType(types_pb2.DT_INVALID) + with self.assertRaises(TypeError): + types.as_dtype(types_pb2.DT_INVALID) + + def testNumpyConversion(self): + self.assertIs(types.float32, types.as_dtype(np.float32)) + self.assertIs(types.float64, types.as_dtype(np.float64)) + self.assertIs(types.int32, types.as_dtype(np.int32)) + self.assertIs(types.int64, types.as_dtype(np.int64)) + self.assertIs(types.uint8, types.as_dtype(np.uint8)) + self.assertIs(types.int16, types.as_dtype(np.int16)) + self.assertIs(types.int8, types.as_dtype(np.int8)) + self.assertIs(types.complex64, types.as_dtype(np.complex64)) + self.assertIs(types.string, types.as_dtype(np.object)) + self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype)) + self.assertIs(types.bool, types.as_dtype(np.bool)) + with self.assertRaises(TypeError): + types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)])) + + def testStringConversion(self): + self.assertIs(types.float32, types.as_dtype("float32")) + self.assertIs(types.float64, types.as_dtype("float64")) + self.assertIs(types.int32, types.as_dtype("int32")) + self.assertIs(types.uint8, types.as_dtype("uint8")) + self.assertIs(types.int16, types.as_dtype("int16")) + self.assertIs(types.int8, types.as_dtype("int8")) + self.assertIs(types.string, types.as_dtype("string")) + self.assertIs(types.complex64, types.as_dtype("complex64")) + self.assertIs(types.int64, types.as_dtype("int64")) + self.assertIs(types.bool, types.as_dtype("bool")) + self.assertIs(types.qint8, 
types.as_dtype("qint8")) + self.assertIs(types.quint8, types.as_dtype("quint8")) + self.assertIs(types.qint32, types.as_dtype("qint32")) + self.assertIs(types.bfloat16, types.as_dtype("bfloat16")) + self.assertIs(types.float32_ref, types.as_dtype("float32_ref")) + self.assertIs(types.float64_ref, types.as_dtype("float64_ref")) + self.assertIs(types.int32_ref, types.as_dtype("int32_ref")) + self.assertIs(types.uint8_ref, types.as_dtype("uint8_ref")) + self.assertIs(types.int16_ref, types.as_dtype("int16_ref")) + self.assertIs(types.int8_ref, types.as_dtype("int8_ref")) + self.assertIs(types.string_ref, types.as_dtype("string_ref")) + self.assertIs(types.complex64_ref, types.as_dtype("complex64_ref")) + self.assertIs(types.int64_ref, types.as_dtype("int64_ref")) + self.assertIs(types.bool_ref, types.as_dtype("bool_ref")) + self.assertIs(types.qint8_ref, types.as_dtype("qint8_ref")) + self.assertIs(types.quint8_ref, types.as_dtype("quint8_ref")) + self.assertIs(types.qint32_ref, types.as_dtype("qint32_ref")) + self.assertIs(types.bfloat16_ref, types.as_dtype("bfloat16_ref")) + with self.assertRaises(TypeError): + types.as_dtype("not_a_type") + + def testDTypesHaveUniqueNames(self): + dtypes = [] + names = set() + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = types.as_dtype(datatype_enum) + dtypes.append(dtype) + names.add(dtype.name) + self.assertEqual(len(dtypes), len(names)) + + def testIsInteger(self): + self.assertEqual(types.as_dtype("int8").is_integer, True) + self.assertEqual(types.as_dtype("int16").is_integer, True) + self.assertEqual(types.as_dtype("int32").is_integer, True) + self.assertEqual(types.as_dtype("int64").is_integer, True) + self.assertEqual(types.as_dtype("uint8").is_integer, True) + self.assertEqual(types.as_dtype("complex64").is_integer, False) + self.assertEqual(types.as_dtype("float").is_integer, False) + self.assertEqual(types.as_dtype("double").is_integer, False) + 
self.assertEqual(types.as_dtype("string").is_integer, False) + self.assertEqual(types.as_dtype("bool").is_integer, False) + + def testMinMax(self): + # make sure min/max evaluates for all data types that have min/max + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = types.as_dtype(datatype_enum) + numpy_dtype = dtype.as_numpy_dtype + + # ignore types for which there are no minimum/maximum (or we cannot + # compute it, such as for the q* types) + if (dtype.is_quantized or + dtype.base_dtype == types.bool or + dtype.base_dtype == types.string or + dtype.base_dtype == types.complex64): + continue + + print "%s: %s - %s" % (dtype, dtype.min, dtype.max) + + # check some values that are known + if numpy_dtype == np.bool_: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 1) + if numpy_dtype == np.int8: + self.assertEquals(dtype.min, -128) + self.assertEquals(dtype.max, 127) + if numpy_dtype == np.int16: + self.assertEquals(dtype.min, -32768) + self.assertEquals(dtype.max, 32767) + if numpy_dtype == np.int32: + self.assertEquals(dtype.min, -2147483648) + self.assertEquals(dtype.max, 2147483647) + if numpy_dtype == np.int64: + self.assertEquals(dtype.min, -9223372036854775808) + self.assertEquals(dtype.max, 9223372036854775807) + if numpy_dtype == np.uint8: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 255) + if numpy_dtype == np.uint16: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 4294967295) + if numpy_dtype == np.uint32: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 18446744073709551615) + if numpy_dtype in (np.float16, np.float32, np.float64): + self.assertEquals(dtype.min, np.finfo(numpy_dtype).min) + self.assertEquals(dtype.max, np.finfo(numpy_dtype).max) + + def testRepr(self): + for enum, name in types._TYPE_TO_STRING.iteritems(): + dtype = types.DType(enum) + self.assertEquals(repr(dtype), 'tf.' 
+ name) + dtype2 = eval(repr(dtype)) + self.assertEquals(type(dtype2), types.DType) + self.assertEquals(dtype, dtype2) + + +if __name__ == "__main__": + googletest.main() |