diff options
Diffstat (limited to 'tensorflow/python/client')
-rwxr-xr-x | tensorflow/python/client/__init__.py | 0 | ||||
-rw-r--r-- | tensorflow/python/client/client_lib.py | 40 | ||||
-rw-r--r-- | tensorflow/python/client/events_writer.i | 34 | ||||
-rw-r--r-- | tensorflow/python/client/events_writer_test.py | 54 | ||||
-rw-r--r-- | tensorflow/python/client/graph_util.py | 138 | ||||
-rw-r--r-- | tensorflow/python/client/graph_util_test.py | 126 | ||||
-rw-r--r-- | tensorflow/python/client/notebook.py | 104 | ||||
-rw-r--r-- | tensorflow/python/client/session.py | 567 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 555 | ||||
-rw-r--r-- | tensorflow/python/client/tensorflow_server.i | 16 | ||||
-rw-r--r-- | tensorflow/python/client/test_construction_fails_op.cc | 22 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 235 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 518 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 56 |
14 files changed, 2465 insertions, 0 deletions
diff --git a/tensorflow/python/client/__init__.py b/tensorflow/python/client/__init__.py new file mode 100755 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/python/client/__init__.py diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py new file mode 100644 index 0000000000..9148ed17c0 --- /dev/null +++ b/tensorflow/python/client/client_lib.py @@ -0,0 +1,40 @@ +# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long +"""This library contains classes for launching graphs and executing operations. + +The [basic usage](../../get_started/index.md#basic-usage) guide has +examples of how a graph is launched in a [`tf.Session`](#Session). + +## Session management + +@@Session + +@@get_default_session + +## Error classes + +@@OpError +@@CancelledError +@@UnknownError +@@InvalidArgumentError +@@DeadlineExceededError +@@NotFoundError +@@AlreadyExistsError +@@PermissionDeniedError +@@UnauthenticatedError +@@ResourceExhaustedError +@@FailedPreconditionError +@@AbortedError +@@OutOfRangeError +@@UnimplementedError +@@InternalError +@@UnavailableError +@@DataLossError +""" + +from tensorflow.python.client.session import InteractiveSession +from tensorflow.python.client.session import Session + +from tensorflow.python.framework import errors +from tensorflow.python.framework.errors import OpError + +from tensorflow.python.framework.ops import get_default_session diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i new file mode 100644 index 0000000000..cbf42e2791 --- /dev/null +++ b/tensorflow/python/client/events_writer.i @@ -0,0 +1,34 @@ +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/core/util/events_writer.h" +#include "tensorflow/core/util/event.pb.h" +%} + +%nodefaultctor EventsWriter; + +%ignoreall +%unignore tensorflow; +%unignore tensorflow::EventsWriter; +%unignore tensorflow::EventsWriter::EventsWriter; +%unignore tensorflow::EventsWriter::~EventsWriter; +%unignore tensorflow::EventsWriter::FileName; +%rename("_WriteSerializedEvent") tensorflow::EventsWriter::WriteSerializedEvent; +%unignore tensorflow::EventsWriter::Flush; +%unignore tensorflow::EventsWriter::Close; +%include "tensorflow/core/util/events_writer.h" +%unignoreall + +%newobject tensorflow::EventsWriter::EventsWriter; + + +%extend tensorflow::EventsWriter { +%insert("python") %{ + def WriteEvent(self, event): + from tensorflow.core.util.event_pb2 import Event + if not isinstance(event, Event): + raise TypeError("Expected an event_pb2.Event proto, " + " but got %s" % type(event)) + return self._WriteSerializedEvent(event.SerializeToString()) +%} +} diff --git a/tensorflow/python/client/events_writer_test.py b/tensorflow/python/client/events_writer_test.py new file mode 100644 index 0000000000..60bce49b1f --- /dev/null +++ b/tensorflow/python/client/events_writer_test.py @@ -0,0 +1,54 @@ +"""Tests for the SWIG-wrapped events writer.""" +import os.path + +from tensorflow.core.framework import summary_pb2 +from tensorflow.core.util import event_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.lib.io import tf_record +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class PywrapeventsWriterTest(test_util.TensorFlowTestCase): + + def testWriteEvents(self): + file_prefix = os.path.join(self.get_temp_dir(), "events") + writer = pywrap_tensorflow.EventsWriter(file_prefix) + filename = writer.FileName() + event_written = event_pb2.Event( + wall_time=123.45, step=67, + summary=summary_pb2.Summary( + value=[summary_pb2.Summary.Value(tag="foo", simple_value=89.0)])) + writer.WriteEvent(event_written) + writer.Flush() + writer.Close() + + with self.assertRaises(IOError): + for r in tf_record.tf_record_iterator(filename + "DOES_NOT_EXIST"): + self.assertTrue(False) + + reader = tf_record.tf_record_iterator(filename) + event_read = event_pb2.Event() + + event_read.ParseFromString(next(reader)) + self.assertTrue(event_read.HasField("file_version")) + + event_read.ParseFromString(next(reader)) + # Second event + self.assertProtoEquals(""" + wall_time: 123.45 step: 67 + summary { value { tag: 'foo' simple_value: 89.0 } } + """, event_read) + + with self.assertRaises(StopIteration): + next(reader) + + def testWriteEventInvalidType(self): + class _Invalid(object): + def __str__(self): return "Invalid" + with self.assertRaisesRegexp(TypeError, "Invalid"): + pywrap_tensorflow.EventsWriter("foo").WriteEvent(_Invalid()) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py new file mode 100644 index 0000000000..4c65a445ae --- /dev/null +++ b/tensorflow/python/client/graph_util.py @@ -0,0 +1,138 @@ +"""Helpers to manipulate a tensor graph in python. +""" + +import tensorflow.python.platform + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import ops +from tensorflow.python.framework import types +from tensorflow.python.platform import logging + +_VARIABLE_OPS = { + "Assign", + "AssignAdd", + "AssignSub", + "Queue", + "RandomParameters", + "ScatterAdd", + "ScatterSub", + "ScatterUpdate", + "Variable", +} + + +def _is_variable_op(op): + """Returns true if 'op' refers to a Variable node.""" + return op in _VARIABLE_OPS + + +def set_cpu0(device_string): + """Creates a new device string based on `device_string' but using /CPU:0. + + If the device is already on /CPU:0, this is a no-op. + + Args: + device_string: A device string. + + Returns: + A device string. + """ + parsed_device = pydev.from_string(device_string) + parsed_device.device_type = "CPU" + parsed_device.device_index = 0 + return parsed_device.to_string() + + +def must_run_on_cpu(node, pin_variables_on_cpu=False): + """Returns True if the given node_def must run on CPU, otherwise False. + + Args: + node: The node to be assigned to a device. Could be either an ops.Operation + or NodeDef. + pin_variables_on_cpu: If True, this function will return False if node_def + represents a variable-related op. + + Returns: + True if the given node must run on CPU, otherwise False. + """ + + if isinstance(node, ops.Operation): + node_def = node.node_def + else: + assert isinstance(node, graph_pb2.NodeDef) + node_def = node + + # If the op is a variable-related op, should we pin it on CPU? + if pin_variables_on_cpu and _is_variable_op(node_def.op): + return True + + # Constant operations producing a string or int32 must run on CPU. + if node_def.op == "Const": + # Get the value of the 'dtype' attr + dtype = node_def.attr["dtype"].type + if dtype == types.string or dtype == types.int32: + return True + + if node_def.op == "DynamicStitch": + dtype = node_def.attr["T"].type + if dtype == types.int32: + # DynamicStitch on GPU only works for int32 values. + return True + + if node_def.op in ["Cast"]: + dtype = node_def.attr["SrcT"].type + if dtype == types.int32: + # Cast on GPU does not works for int32 values. + return True + return False + + +################################################################################ +# +# device functions for use in with g.device(...) +# +################################################################################ + + +def pin_variables_on_cpu(op): + """Returns a CPU device for Variable nodes if the device is not specified. + + Args: + op: The ops.Operation object describing the node for which a device + should be chosen. The op.device field is respected. + + Returns: + A device containing "/device:CPU:0" if the node is related to a variable. + """ + device = op.device if op.device is not None else "" + dev = pydev.from_string(device) + + # If a device type exists already, do not override. + if dev.device_type: + return device + + if isinstance(op, ops.Operation): + node_def = op.node_def + else: + assert isinstance(op, graph_pb2.NodeDef) + node_def = op + + if _is_variable_op(node_def.op): + return set_cpu0(device) + return device + + +def pin_to_cpu(op): + """Returns a CPU device for the given node.""" + device = op.device if op.device is not None else "" + dev = pydev.from_string(device) + + if not dev.device_type: + return set_cpu0(device) + if dev.device_type == "CPU": + return device + + logging.info("Operation %s has been assigned to a non-CPU (%s), so " + "it will not be pinned to the CPU.", op.name, dev.device_type) + return device diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py new file mode 100644 index 0000000000..8066f722a8 --- /dev/null +++ b/tensorflow/python/client/graph_util_test.py @@ -0,0 +1,126 @@ +"""Tests for tensorflow.python.client.graph_util.""" +import tensorflow.python.platform + +from tensorflow.python.client import graph_util +from tensorflow.python.framework import ops +from tensorflow.python.framework import types +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import data_flow_ops +# pylint: disable=unused-import +from tensorflow.python.ops import math_ops +# pylint: enable=unused-import +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import googletest + + +class DeviceFunctionsTest(googletest.TestCase): + + def testPinToCpu(self): + with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu): + const_a = constant_op.constant(5.0) + const_b = constant_op.constant(10.0) + add_c = const_a + const_b + var_v = state_ops.variable_op([], dtype=types.float32) + assign_c_to_v = state_ops.assign(var_v, add_c) + const_string = constant_op.constant("on a cpu") + dynamic_stitch_int_result = data_flow_ops.dynamic_stitch( + [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]]) + dynamic_stitch_float_result = data_flow_ops.dynamic_stitch( + [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]]) + self.assertEqual(const_a.device, "/device:CPU:0") + self.assertEqual(const_b.device, "/device:CPU:0") + self.assertEqual(add_c.device, "/device:CPU:0") + self.assertEqual(var_v.device, "/device:CPU:0") + self.assertEqual(assign_c_to_v.device, "/device:CPU:0") + self.assertEqual(const_string.device, "/device:CPU:0") + self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0") + self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0") + + def testPinRequiredOpsOnCPU(self): + with ops.Graph().as_default() as g, g.device( + graph_util.pin_variables_on_cpu): + const_a = constant_op.constant(5.0) + const_b = constant_op.constant(10.0) + add_c = const_a + const_b + var_v = state_ops.variable_op([], dtype=types.float32) + assign_c_to_v = state_ops.assign(var_v, add_c) + dynamic_stitch_int_result = data_flow_ops.dynamic_stitch( + [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]]) + dynamic_stitch_float_result = data_flow_ops.dynamic_stitch( + [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]]) + # Non-variable ops shuld not specify a device + self.assertEqual(const_a.device, None) + self.assertEqual(const_b.device, None) + self.assertEqual(add_c.device, None) + # Variable ops specify a device + self.assertEqual(var_v.device, "/device:CPU:0") + self.assertEqual(assign_c_to_v.device, "/device:CPU:0") + + def testTwoDeviceFunctions(self): + with ops.Graph().as_default() as g: + var_0 = state_ops.variable_op([1], dtype=types.float32) + with g.device(graph_util.pin_variables_on_cpu): + var_1 = state_ops.variable_op([1], dtype=types.float32) + var_2 = state_ops.variable_op([1], dtype=types.float32) + var_3 = state_ops.variable_op([1], dtype=types.float32) + with g.device(graph_util.pin_variables_on_cpu): + var_4 = state_ops.variable_op([1], dtype=types.float32) + with g.device("/device:GPU:0"): + var_5 = state_ops.variable_op([1], dtype=types.float32) + var_6 = state_ops.variable_op([1], dtype=types.float32) + + self.assertEqual(var_0.device, None) + self.assertEqual(var_1.device, "/device:CPU:0") + self.assertEqual(var_2.device, None) + self.assertEqual(var_3.device, None) + self.assertEqual(var_4.device, "/device:CPU:0") + self.assertEqual(var_5.device, "/device:GPU:0") + self.assertEqual(var_6.device, "/device:CPU:0") + + def testExplicitDevice(self): + with ops.Graph().as_default() as g: + const_0 = constant_op.constant(5.0) + with g.device("/device:GPU:0"): + const_1 = constant_op.constant(5.0) + with g.device("/device:GPU:1"): + const_2 = constant_op.constant(5.0) + with g.device("/device:CPU:0"): + const_3 = constant_op.constant(5.0) + with g.device("/device:CPU:1"): + const_4 = constant_op.constant(5.0) + with g.device("/job:ps"): + const_5 = constant_op.constant(5.0) + + self.assertEqual(const_0.device, None) + self.assertEqual(const_1.device, "/device:GPU:0") + self.assertEqual(const_2.device, "/device:GPU:1") + self.assertEqual(const_3.device, "/device:CPU:0") + self.assertEqual(const_4.device, "/device:CPU:1") + self.assertEqual(const_5.device, "/job:ps") + + def testDefaultDevice(self): + with ops.Graph().as_default() as g, g.device( + graph_util.pin_variables_on_cpu): + with g.device("/job:ps"): + const_0 = constant_op.constant(5.0) + with g.device("/device:GPU:0"): + const_1 = constant_op.constant(5.0) + with g.device("/device:GPU:1"): + const_2 = constant_op.constant(5.0) + with g.device("/device:CPU:0"): + const_3 = constant_op.constant(5.0) + with g.device("/device:CPU:1"): + const_4 = constant_op.constant(5.0) + with g.device("/replica:0"): + const_5 = constant_op.constant(5.0) + + self.assertEqual(const_0.device, "/job:ps") + self.assertEqual(const_1.device, "/device:GPU:0") + self.assertEqual(const_2.device, "/device:GPU:1") + self.assertEqual(const_3.device, "/device:CPU:0") + self.assertEqual(const_4.device, "/device:CPU:1") + self.assertEqual(const_5.device, "/replica:0") + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py new file mode 100644 index 0000000000..1871fbc632 --- /dev/null +++ b/tensorflow/python/client/notebook.py @@ -0,0 +1,104 @@ +"""Notebook front-end to TensorFlow. + +When you run this binary, you'll see something like below, which indicates +the serving URL of the notebook: + + The IPython Notebook is running at: http://127.0.0.1:8888/ + +Press "Shift+Enter" to execute a cell +Press "Enter" on a cell to go into edit mode. +Press "Escape" to go back into command mode and use arrow keys to navigate. +Press "a" in command mode to insert cell above or "b" to insert cell below. + +Your root notebooks directory is FLAGS.notebook_dir +""" + + +import os +import socket +import sys + +# pylint: disable=g-import-not-at-top +# Official recommended way of turning on fast protocol buffers as of 10/21/14 +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" + +from tensorflow.python.platform import app +from tensorflow.python.platform import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "password", None, + "Password to require. If set, the server will allow public access." + " Only used if notebook config file does not exist.") + +flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks", + "root location where to store notebooks") + +ORIG_ARGV = sys.argv +# Main notebook process calls itself with argv[1]="kernel" to start kernel +# subprocesses. +IS_KERNEL = len(sys.argv) > 1 and sys.argv[1] == "kernel" + + +def main(unused_argv): + sys.argv = ORIG_ARGV + + if not IS_KERNEL: + # Drop all flags. + sys.argv = [sys.argv[0]] + # NOTE(sadovsky): For some reason, putting this import at the top level + # breaks inline plotting. It's probably a bug in the stone-age version of + # matplotlib. + from IPython.html.notebookapp import NotebookApp # pylint: disable=g-import-not-at-top + notebookapp = NotebookApp.instance() + notebookapp.open_browser = True + + # password functionality adopted from quality/ranklab/main/tools/notebook.py + # add options to run with "password" + if FLAGS.password: + from IPython.lib import passwd # pylint: disable=g-import-not-at-top + notebookapp.ip = "0.0.0.0" + notebookapp.password = passwd(FLAGS.password) + else: + print ("\nNo password specified; Notebook server will only be available" + " on the local machine.\n") + notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir]) + + if notebookapp.ip == "0.0.0.0": + proto = "https" if notebookapp.certfile else "http" + url = "%s://%s:%d%s" % (proto, socket.gethostname(), notebookapp.port, + notebookapp.base_project_url) + print "\nNotebook server will be publicly available at: %s\n" % url + + notebookapp.start() + return + + # Drop the --flagfile flag so that notebook doesn't complain about an + # "unrecognized alias" when parsing sys.argv. + sys.argv = ([sys.argv[0]] + + [z for z in sys.argv[1:] if not z.startswith("--flagfile")]) + from IPython.kernel.zmq.kernelapp import IPKernelApp # pylint: disable=g-import-not-at-top + kernelapp = IPKernelApp.instance() + kernelapp.initialize() + + # Enable inline plotting. Equivalent to running "%matplotlib inline". + ipshell = kernelapp.shell + ipshell.enable_matplotlib("inline") + + kernelapp.start() + + +if __name__ == "__main__": + # When the user starts the main notebook process, we don't touch sys.argv. + # When the main process launches kernel subprocesses, it writes all flags + # to a tmpfile and sets --flagfile to that tmpfile, so for kernel + # subprocesses here we drop all flags *except* --flagfile, then call + # app.run(), and then (in main) restore all flags before starting the + # kernel app. + if IS_KERNEL: + # Drop everything except --flagfile. + sys.argv = ([sys.argv[0]] + + [x for x in sys.argv[1:] if x.startswith("--flagfile")]) + app.run() diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py new file mode 100644 index 0000000000..7da9b41cf4 --- /dev/null +++ b/tensorflow/python/client/session.py @@ -0,0 +1,567 @@ +"""A client interface for TensorFlow.""" + +import re +import sys +import threading + +import tensorflow.python.platform + +import numpy as np + +from tensorflow.python import pywrap_tensorflow as tf_session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import logging + + +class SessionInterface(object): + """Base class for implementations of TensorFlow client sessions.""" + + @property + def graph(self): + """The underlying TensorFlow graph, to be used in building Operations.""" + raise NotImplementedError('graph') + + @property + def sess_str(self): + """The TensorFlow process to which this session will connect.""" + raise NotImplementedError('sess_str') + + def run(self, fetches, feed_dict=None): + """Runs operations in the session. See `Session.run()` for details.""" + raise NotImplementedError('Run') + + +class BaseSession(SessionInterface): + """A class for interacting with a TensorFlow computation. + + The BaseSession enables incremental graph building with inline + execution of Operations and evaluation of Tensors. + """ + + def __init__(self, target='', graph=None, config=None): + """Constructs a new TensorFlow session. + + Args: + target: (Optional) The TensorFlow execution engine to connect to. + graph: (Optional) The graph to be used. If this argument is None, + the default graph will be used. + config: (Optional) ConfigProto proto used to configure the session. + + Raises: + RuntimeError: If an error occurs while creating the TensorFlow + session. + """ + if graph is None: + self._graph = ops.get_default_graph() + else: + self._graph = graph + + self._opened = False + self._closed = False + + self._current_version = 0 + self._extend_lock = threading.Lock() + self._target = target + + self._session = None + + try: + opts = tf_session.TF_NewSessionOptions(target=target, config=config) + status = tf_session.TF_NewStatus() + self._session = tf_session.TF_NewSession(opts, status) + if tf_session.TF_GetCode(status) != 0: + message = tf_session.TF_Message(status) + raise RuntimeError(message) + + finally: + tf_session.TF_DeleteSessionOptions(opts) + tf_session.TF_DeleteStatus(status) + + def close(self): + """Closes this session. + + Calling this method frees all resources associated with the session. + + Raises: + RuntimeError: If an error occurs while closing the session. + """ + with self._extend_lock: + if self._opened and not self._closed: + self._closed = True + try: + status = tf_session.TF_NewStatus() + tf_session.TF_CloseSession(self._session, status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + finally: + tf_session.TF_DeleteStatus(status) + + def __del__(self): + self.close() + try: + status = tf_session.TF_NewStatus() + if self._session is not None: + tf_session.TF_DeleteSession(self._session, status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + self._session = None + finally: + tf_session.TF_DeleteStatus(status) + + @property + def graph(self): + """The graph that was launched in this session.""" + return self._graph + + @property + def graph_def(self): + """A serializable version of the underlying TensorFlow graph. + + Returns: + A graph_pb2.GraphDef proto containing nodes for all of the Operations in + the underlying TensorFlow graph. + """ + return self._graph.as_graph_def() + + @property + def sess_str(self): + return self._target + + def as_default(self): + """Returns a context manager that makes this object the default session. + + Use with the `with` keyword to specify that calls to + [`Operation.run()`](framework.md#Operation.run) or + [`Tensor.run()`](framework.md#Tensor.run) should be executed in + this session. + + ```python + c = tf.constant(..) + sess = tf.Session() + + with sess.as_default(): + assert tf.get_default_session() is sess + print c.eval() + ``` + + To get the current default session, use + [`tf.get_default_session()`](#get_default_session). + + + *N.B.* The `as_default` context manager *does not* close the + session when you exit the context, and you must close the session + explicitly. + + ```python + c = tf.constant(...) + sess = tf.Session() + with sess.as_default(): + print c.eval() + # ... + with sess.as_default(): + print c.eval() + + sess.close() + ``` + + Alternatively, you can use `with tf.Session():` to create a + session that is automatically closed on exiting the context, + including when an uncaught exception is raised. + + *N.B.* The default graph is a property of the current thread. If you + create a new thread, and wish to use the default session in that + thread, you must explicitly add a `with sess.as_default():` in that + thread's function. + + Returns: + A context manager using this session as the default session. + + """ + return ops.default_session(self) + + # Eventually, this registration could be opened up to support custom + # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn), + # where the signatures are: + # fetch_fn : Type -> (list of Tensors, + # lambda: list of fetched np.ndarray -> TypeVal) + # feed_fn : Type, TypeVal -> list of (Tensor, value) + # Conceptually, fetch_fn describes how to expand fetch into its + # component Tensors and how to contracting the fetched results back into + # a single return value. feed_fn describes how to unpack a single fed + # value and map it to feeds of a Tensor and its corresponding value. + # pylint: disable=g-long-lambda + _REGISTERED_EXPANSIONS = [ + # SparseTensors are fetched as SparseTensorValues. They can be fed + # SparseTensorValues or normal tuples. + (ops.SparseTensor, + lambda fetch: ( + [fetch.indices, fetch.values, fetch.shape], + lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)), + lambda feed, feed_val: list(zip( + [feed.indices, feed.values, feed.shape], feed_val))), + # The default catches all types and performs no expansions. + (object, + lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), + lambda feed, feed_val: [(feed, feed_val)])] + # pylint: enable=g-long-lambda + + def run(self, fetches, feed_dict=None): + """Runs the operations and evaluates the tensors in `fetches`. + + This method runs one "step" of TensorFlow computation, by + running the necessary graph fragment to execute every `Operation` + and evaluate every `Tensor` in `fetches`, substituting the values in + `feed_dict` for the corresponding input values. + + The `fetches` argument may be a list of graph elements or a single + graph element, and these determine the return value of this + method. A graph element can be one of the following types: + + * If the *i*th element of `fetches` is an + [`Operation`](framework.md#Operation), the *i*th return value + will be `None`. + * If the *i*th element of `fetches` is a + [`Tensor`](framework.md#Tensor), the *i*th return value will + be a numpy ndarray containing the value of that tensor. + * If the *i*th element of `fetches` is a + [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th + return value will be a + [`SparseTensorValue`](sparse_ops.md#SparseTensorValue) + containing the value of that sparse tensor. + + The optional `feed_dict` argument allows the caller to override + the value of tensors in the graph. Each key in `feed_dict` can be + one of the following types: + + * If the key is a [`Tensor`](framework.md#Tensor), the + value may be a Python scalar, string, list, or numpy ndarray + that can be converted to the same `dtype` as that + tensor. Additionally, if the key is a + [placeholder](io_ops.md#placeholder), the shape of the value + will be checked for compatibility with the placeholder. + * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor), + the value should be a + [`SparseTensorValue`](sparse_ops.md#SparseTensorValue). + + Args: + fetches: A single graph element, or a list of graph elements + (described above). + feed_dict: A dictionary that maps graph elements to values + (described above). + + Returns: + Either a single value if `fetches` is a single graph element, or + a list of values if `fetches` is a list (described above). + + Raises: + RuntimeError: If this `Session` is in an invalid state (e.g. has been + closed). + TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. + ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a + `Tensor` that doesn't exist. + + """ + def _fetch_fn(fetch): + for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS: + if isinstance(fetch, tensor_type): + return fetch_fn(fetch) + raise TypeError('Fetch argument %r has invalid type %r' + % (fetch, type(fetch))) + + def _feed_fn(feed, feed_val): + for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS: + if isinstance(feed, tensor_type): + return feed_fn(feed, feed_val) + raise TypeError('Feed argument %r has invalid type %r' + % (feed, type(feed))) + + # Check session. + if self._closed: + raise RuntimeError('Attempted to use a closed Session.') + + # Validate and process fetches. + is_list_fetch = isinstance(fetches, (list, tuple)) + if not is_list_fetch: + fetches = [fetches] + + unique_fetch_targets = set() + target_list = [] + + fetch_info = [] + for fetch in fetches: + subfetches, fetch_contraction_fn = _fetch_fn(fetch) + subfetch_names = [] + for subfetch in subfetches: + try: + fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True, + allow_operation=True) + if isinstance(fetch_t, ops.Operation): + target_list.append(fetch_t.name) + else: + subfetch_names.append(fetch_t.name) + except TypeError as e: + raise TypeError('Fetch argument %r of %r has invalid type %r, ' + 'must be a string or Tensor. (%s)' + % (subfetch, fetch, type(subfetch), e.message)) + except ValueError as e: + raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' + 'Tensor. (%s)' % (subfetch, fetch, e.message)) + except KeyError as e: + raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' + 'Tensor. (%s)' % (subfetch, fetch, e.message)) + unique_fetch_targets.update(subfetch_names) + fetch_info.append((subfetch_names, fetch_contraction_fn)) + + unique_fetch_targets = list(unique_fetch_targets) + + # Create request. + feed_dict_string = {} + + # Validate and process feed_dict. + if feed_dict: + for feed, feed_val in feed_dict.iteritems(): + for subfeed, subfeed_val in _feed_fn(feed, feed_val): + try: + subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True, + allow_operation=False) + except Exception as e: + e.message = ('Cannot interpret feed_dict key as Tensor: ' + + e.message) + e.args = (e.message,) + raise e + np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype) + if subfeed_t.op.type == 'Placeholder': + if not subfeed_t.get_shape().is_compatible_with(np_val.shape): + raise ValueError( + 'Cannot feed value of shape %r for Tensor %r, ' + 'which has shape %r' + % (np_val.shape, subfeed_t.name, + tuple(subfeed_t.get_shape().dims))) + feed_dict_string[str(subfeed_t.name)] = np_val + + # Run request and get response. + results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) + + # User may have fetched the same tensor multiple times, but we + # only fetch them from the runtime once. Furthermore, they may + # be wrapped as a tuple of tensors. Here we map the results back + # to what the client asked for. + fetched_results = dict(zip(unique_fetch_targets, results)) + ret = [] + for fetch_names, fetch_contraction_fn in fetch_info: + if fetch_names: + fetched_vals = [fetched_results[name] for name in fetch_names] + ret.append(fetch_contraction_fn(fetched_vals)) + else: + ret.append(None) + + if is_list_fetch: + return ret + else: + return ret[0] + + # Captures the name of a node in an error status. + _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') + + def _do_run(self, target_list, fetch_list, feed_dict): + """Runs a step based on the given fetches and feeds. + + Args: + target_list: A list of strings corresponding to names of tensors + or operations to be run to, but not fetched. + fetch_list: A list of strings corresponding to names of tensors to be + fetched and operations to be run. + feed_dict: A dictionary that maps tensor names to numpy ndarrays. + + Returns: + A list of numpy ndarrays, corresponding to the elements of + `fetch_list`. If the ith element of `fetch_list` contains the + name of an operation, the first Tensor output of that operation + will be returned for that element. + """ + try: + # Ensure any changes to the graph are reflected in the runtime. + with self._extend_lock: + if self._graph.version > self._current_version: + graph_def = self._graph.as_graph_def( + from_version=self._current_version) + + try: + status = tf_session.TF_NewStatus() + tf_session.TF_ExtendGraph( + self._session, graph_def.SerializeToString(), status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(tf_session.TF_Message(status)) + self._opened = True + finally: + tf_session.TF_DeleteStatus(status) + + self._current_version = self._graph.version + + return tf_session.TF_Run(self._session, feed_dict, fetch_list, + target_list) + + except tf_session.StatusNotOK as e: + e_type, e_value, e_traceback = sys.exc_info() + m = BaseSession._NODEDEF_NAME_RE.search(e.error_message) + if m is not None: + node_name = m.group(1) + node_def = None + try: + op = self._graph.get_operation_by_name(node_name) + node_def = op.node_def + except KeyError: + op = None + # pylint: disable=protected-access + raise errors._make_specific_exception(node_def, op, e.error_message, + e.code) + # pylint: enable=protected-access + raise e_type, e_value, e_traceback + + +class Session(BaseSession): + """A class for running TensorFlow operations. + + A `Session` object encapsulates the environment in which `Operation` + objects are executed, and `Tensor` objects are evaluated. For + example: + + ```python + # Build a graph. + a = tf.constant(5.0) + b = tf.constant(6.0) + c = a * b + + # Launch the graph in a session. + sess = tf.Session() + + # Evaluate the tensor `c`. + print sess.run(c) + ``` + + A session may own resources, such as + [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase), + and [readers](io_ops.md#ReaderBase). It is important to release + these resources when they are no longer required. To do this, either + invoke the [`close()`](#Session.close) method on the session, or use + the session as a context manager. The following two examples are + equivalent: + + ```python + # Using the `close()` method. + sess = tf.Session() + sess.run(...) + sess.close() + + # Using the context manager. + with tf.Session() as sess: + sess.run(...) + ``` + + The [`ConfigProto`] + (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto) + protocol buffer exposes various configuration options for a + session. For example, to create a session that uses soft constraints + for device placement, and log the resulting placement decisions, + create a session as follows: + + ```python + # Launch the graph in a session that allows soft device placement and + # logs the placement decisions. + sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, + log_device_placement=True)) + ``` + + @@__init__ + @@run + @@close + + @@graph + + @@as_default + + """ + + def __init__(self, target='', graph=None, config=None): + """Creates a new TensorFlow session. + + If no `graph` argument is specified when constructing the session, + the default graph will be launched in the session. If you are + using more than one graph (created with `tf.Graph()` in the same + process, you will have to use different sessions for each graph, + but each graph can be used in multiple sessions. In this case, it + is often clearer to pass the graph to be launched explicitly to + the session constructor. + + Args: + target: (Optional.) The execution engine to connect to. + Defaults to using an in-process engine. At present, no value + other than the empty string is supported. + graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto) + protocol buffer with configuration options for the session. + + """ + super(Session, self).__init__(target, graph, config=config) + self._context_managers = [self.graph.as_default(), self.as_default()] + + def __enter__(self): + for context_manager in self._context_managers: + context_manager.__enter__() + return self + + def __exit__(self, exec_type, exec_value, exec_tb): + if exec_type is errors.OpError: + logging.error('Session closing due to OpError: %s', (exec_value,)) + + for context_manager in reversed(self._context_managers): + context_manager.__exit__(exec_type, exec_value, exec_tb) + + self.close() + + +class InteractiveSession(BaseSession): + """A TensorFlow `Session` for use in interactive contexts, such as a shell. + + In some cases, such as interactive shells and IPython notebooks, it is + useful to be able to define a `Session` without using a with block: this + style enables statements to be executed immediately, rather than at the + termination of the block. In that case, it must be closed using + `Session.close()`. For example: + + ```python + sess = InteractiveSession() + a = tf.constant(5.0) + b = tf.constant(6.0) + c = a * b + print c.run() + sess.close() + ``` + + @@__init__ + @@close + """ + + def __init__(self, target='', graph=None): + """Initializes an `InteractiveSession` object similar to `Session`. + + Args: + target: Optional. The TensorFlow execution engine to connect to. + graph: Optional. The `Graph` object to be used. If this argument is None, + the default graph will be used. + """ + super(InteractiveSession, self).__init__(target, graph) + self._default_session = self.as_default() + self._default_session.__enter__() + self._explicit_graph = graph + if self._explicit_graph is not None: + self._default_graph = graph.as_default() + self._default_graph.__enter__() + + def close(self): + """Closes an `InteractiveSession`.""" + super(InteractiveSession, self).close() + if self._explicit_graph is not None: + self._default_graph.__exit__(None, None, None) + self._default_session.__exit__(None, None, None) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py new file mode 100644 index 0000000000..4492840dcf --- /dev/null +++ b/tensorflow/python/client/session_test.py @@ -0,0 +1,555 @@ +"""Tests for tensorflow.python.client.session.Session.""" +import threading +import time + +import tensorflow.python.platform + +import numpy as np + +from tensorflow.core.framework import config_pb2 +from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_util +from tensorflow.python.framework import types +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import constant_op +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +# NOTE(mrry): Dummy shape registration for op used in the tests. +ops.RegisterShape('ConstructionFails')(None) + + +class SessionTest(test_util.TensorFlowTestCase): + + def testUseExistingGraph(self): + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + a = constant_op.constant(6.0, shape=[1, 1]) + b = constant_op.constant(7.0, shape=[1, 1]) + c = math_ops.matmul(a, b, name='matmul') + with session.Session(graph=g): + result = c.eval() + self.assertAllEqual(result, [[42.0]]) + + def testUseDefaultGraph(self): + with ops.Graph().as_default(), ops.device('/cpu:0'): + a = constant_op.constant(6.0, shape=[1, 1]) + b = constant_op.constant(7.0, shape=[1, 1]) + c = math_ops.matmul(a, b, name='matmul') + with session.Session(): + result = c.eval() + self.assertAllEqual(result, [[42.0]]) + + def testCreate(self): + with session.Session(): + inp = constant_op.constant(10.0, name='W1') + copy = array_ops.identity(inp) + # Test with feed. + # TODO(mrry): Investigate why order='F' didn't work. + arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C') + copy_val = copy.eval({'W1:0': arr}) + self.assertAllEqual(arr, copy_val) + # Test without feed. + copy_val = copy.eval() + self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val) + + def testManyCPUs(self): + # TODO(keveman): Implement ListDevices and test for the number of + # devices returned by ListDevices. + with session.Session( + config=config_pb2.ConfigProto(device_count={'CPU': 2})): + inp = constant_op.constant(10.0, name='W1') + self.assertAllEqual(inp.eval(), 10.0) + + def testErrorsReported(self): + with session.Session() as s: + constant_op.constant(10.0, name='W1') + with self.assertRaises(ValueError): + s.run('foo:0') + + def testErrorPayload(self): + with session.Session(): + a = array_ops.placeholder(types.float32) + with self.assertRaisesOpError(lambda e: e.op == a.op): + a.eval() + + def testOpConstructionErrorPayload(self): + with session.Session(): + failing_op = ops.get_default_graph().create_op( + 'ConstructionFails', [], [], name='f') + + def exc_predicate(e): + return (e.op == failing_op + and e.error_code == error_codes_pb2.INVALID_ARGUMENT) + with self.assertRaisesOpError(exc_predicate): + failing_op.run() + + def testErrorBasedOn(self): + with session.Session() as sess: + a = constant_op.constant(0.0, shape=[2, 3]) + # NOTE(mrry): The original_op is nonsense, but used here to test that the + # errors are reported correctly. + # pylint: disable=protected-access + with sess.graph._original_op(a.op): + b = array_ops.identity(a, name='id') + with sess.graph._original_op(b.op): + c = array_ops.placeholder(types.float32) + # pylint: enable=protected-access + + def exc_predicate(e): + return (e.op == c.op + and e.op._original_op == b.op + and e.op._original_op._original_op == a.op) + with self.assertRaisesOpError(exc_predicate): + c.eval() + + def testFetchTensorObject(self): + with session.Session() as s: + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + results_with_list = s.run([c]) + self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0]) + results_with_single = s.run(c) + self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single) + results_with_get = c.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get) + a_val, b_val = s.run([a, b]) # Test multiple fetches. + self.assertAllEqual([[1.0, 1.0]], a_val) + self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) + + def testFetchScalar(self): + with session.Session() as s: + for scalar in np.int32, np.int64, np.float32, np.float64: + x = scalar(7) + y = scalar(8) + tf_x = constant_op.constant(x, shape=[]) + tf_y = constant_op.constant(y) + tf_xy = math_ops.add(tf_x, tf_y) + # Single fetch + xy = s.run(tf_xy) + self.assertEqual(scalar, type(xy)) + self.assertEqual(x + y, xy) + # List fetch + xy, = s.run([tf_xy]) + self.assertEqual(scalar, type(xy)) + self.assertEqual(x + y, xy) + + def testFetchOperationObject(self): + with session.Session() as s: + a = constant_op.constant(1.0, shape=[1, 2]) + v = variables.Variable(a, name='testFetchOperationObject_v') + s.run(v.initializer) + v_val = s.run(v) + self.assertAllEqual([[1.0, 1.0]], v_val) + + def testFetchSparseTensor(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + shape = np.array([7, 9, 2]).astype(np.int64) + sp = ops.SparseTensor( + constant_op.constant(indices), + constant_op.constant(values), + constant_op.constant(shape)) + # Single fetch, use as tuple + sp_out = s.run(sp) + indices_out, values_out, shape_out = sp_out + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # Single fetch, use as SparseTensorValue + sp_out = s.run(sp) + self.assertAllEqual(sp_out.indices, indices) + self.assertAllEqual(sp_out.values, values) + self.assertAllEqual(sp_out.shape, shape) + # Tuple fetch, use as tuple + indices_out, values_out, shape_out = s.run(sp) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # List fetch, use as tuple + (indices_out, values_out, shape_out), = s.run([sp]) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # List fetch, use as SparseTensorValue + sp_out, = s.run([sp]) + self.assertAllEqual(sp_out.indices, indices) + self.assertAllEqual(sp_out.values, values) + self.assertAllEqual(sp_out.shape, shape) + + def testFeedSparseTensor(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + shape = np.array([7, 9, 2]).astype(np.int64) + sp = ops.SparseTensor( + array_ops.placeholder(dtype=np.int64, shape=(2, 3)), + array_ops.placeholder(dtype=np.float32, shape=(2,)), + array_ops.placeholder(dtype=np.int64, shape=(3,)),) + sp_indices = array_ops.identity(sp.indices) + sp_values = array_ops.identity(sp.values) + sp_shape = array_ops.identity(sp.shape) + sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape) + # Feed with tuple + indices_out, values_out, shape_out = s.run( + [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)}) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # Feed with SparseTensorValue + indices_out, values_out, shape_out = s.run( + [sp_indices, sp_values, sp_shape], + {sp: ops.SparseTensorValue(indices, values, shape)}) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # Feed with SparseTensorValue, fetch SparseTensorValue + sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)}) + self.assertAllEqual(sp2_out.indices, indices) + self.assertAllEqual(sp2_out.values, values) + self.assertAllEqual(sp2_out.shape, shape) + + def testExtendWithStatelessOperations(self): + with session.Session() as s: + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + c_val = s.run(c) + self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) + d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) + e = math_ops.matmul(c, d) + # Extend will happen here. + e_val = s.run(e) + self.assertAllEqual([[24.0]], e_val) + + def testExtendWithStatefulOperations(self): + with session.Session() as s: + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + v = variables.Variable(c, name='testExtendWithStatefulOperations_v') + v.initializer.run() + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + d = constant_op.constant(3.0, shape=[2, 3]) + e = math_ops.matmul(a, d) + assign_e_to_v = state_ops.assign(v, e) + # Extend will happen here. + e_val = e.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + s.run(assign_e_to_v) + v_val = v.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) + + def testExtendWithGroupBy(self): + with session.Session() as s: + a = constant_op.constant(1.0, shape=[1, 2]) + p = variables.Variable(a, name='testExtendWithGroupBy_p') + a_val = a.eval() # Force an Extend after this op. + self.assertAllEqual([[1.0, 1.0]], a_val) + + b = constant_op.constant(2.0, shape=[1, 2]) + q = variables.Variable(b, name='testExtendWithGroupBy_q') + # Extend will happen here. + init = control_flow_ops.group(p.initializer, q.initializer) + s.run(init) + p_val, q_val = s.run([p, q]) + + self.assertAllEqual([[1.0, 1.0]], p_val) + self.assertAllEqual([[2.0, 2.0]], q_val) + + def testTensorGetMethod(self): + with session.Session(): + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + + c_val = c.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) + + fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]}) + self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val) + + def testOperationRunMethod(self): + with session.Session(): + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[1, 2], name='b') + v = variables.Variable(a, a.dtype) + assign_a_to_v = state_ops.assign(v, a) + + assign_a_to_v.eval() + + v_val = v.eval() + self.assertAllEqual([[1.0, 1.0]], v_val) + + assign_b_to_v = state_ops.assign(v, b) + + assign_b_to_v.eval() + v_val = v.eval() + self.assertAllEqual([[2.0, 2.0]], v_val) + + assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]}) + v_val = v.eval() + self.assertAllEqual([[3.0, 3.0]], v_val) + + def testDefaultGraph(self): + with session.Session() as s: + self.assertEqual(ops.get_default_graph(), s.graph) + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + self.assertEqual(ops.get_default_graph(), a.graph) + self.assertEqual(ops.get_default_graph(), b.graph) + c = math_ops.matmul(a, b) + v = variables.Variable(c, name='testDefaultGraph_v') + v.initializer.run() + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + d = constant_op.constant(3.0, shape=[2, 3]) + e = math_ops.matmul(a, d) + assign_e_to_v = state_ops.assign(v, e) + e_val = e.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + s.run(assign_e_to_v) + v_val = v.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) + self.assertEqual(ops.get_default_graph(), s.graph) + + def _testDefaultGraphInThread(self, constructed_event, continue_event, i): + with session.Session() as s: + self.assertEqual(ops.get_default_graph(), s.graph) + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + v = variables.Variable(c, name='var_%d' % i) + + # Block here until all threads have constructed their graph. + constructed_event.set() + continue_event.wait() + + assign_c_to_v = state_ops.assign(v, c) + v.initializer.run() + assign_c_to_v.eval() + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + d = constant_op.constant(3.0, shape=[2, 3]) + e = math_ops.matmul(a, d) + assign_e_to_v = state_ops.assign(v, e) + e_val = e.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) + v_val = v.eval() + self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) + s.run(assign_e_to_v) + v_val = v.eval() + self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) + self.assertEqual(ops.get_default_graph(), s.graph) + + def testDefaultGraphWithThreads(self): + # Fork ten threads that use their thread-local default graph. + threads = [] + constructed_events = [threading.Event() for _ in range(10)] + continue_event = threading.Event() + for i, constructed_event in enumerate(constructed_events): + t = self.checkedThread(target=self._testDefaultGraphInThread, + args=(constructed_event, continue_event, i)) + threads.append(t) + for t in threads: + t.start() + for constructed_event in constructed_events: + constructed_event.wait() + continue_event.set() + for t in threads: + t.join() + + def testParallelRun(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + ev = threading.Event() + + def run_step(): + ev.wait() + val = c.eval(session=sess) + self.assertEqual(val, 5.0) + threads = [self.checkedThread(target=run_step) for _ in range(100)] + for t in threads: + t.start() + ev.set() + for t in threads: + t.join() + + def testRunFeedDict(self): + with session.Session() as s: + x = array_ops.zeros([2]) + + y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)}) + self.assertAllEqual(y, 2 * np.ones(2)) + + y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)}) + self.assertAllEqual(y, 2 * np.ones(2)) + + y = s.run(2 * x, feed_dict={x: [1, 1]}) + assert (y == 2 * np.ones(2)).all() + + def testGraphDef(self): + with session.Session() as sess: + self.assertProtoEquals('', sess.graph_def) + c = constant_op.constant(5.0, name='c') + self.assertEquals(len(sess.graph_def.node), 1) + d = constant_op.constant(6.0, name='d') + self.assertEquals(len(sess.graph_def.node), 2) + self.assertAllEqual(c.eval(), 5.0) + self.assertAllEqual(d.eval(), 6.0) + e = constant_op.constant(7.0, name='e') + self.assertEquals(len(sess.graph_def.node), 3) + self.assertAllEqual(e.eval(), 7.0) + + def testUseAfterClose(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + self.assertAllEqual(sess.run(c), 5.0) + with self.assertRaisesWithPredicateMatch( + RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)): + sess.run(c) + + def testUseAfterCloseConcurrent(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + self.assertAllEqual(sess.run(c), 5.0) + + def update_thread(): + with self.assertRaisesWithPredicateMatch( + RuntimeError, + lambda e: 'Attempted to use a closed Session.' in str(e)): + while True: + sess.run(c) + t = threading.Thread(target=update_thread) + t.start() + time.sleep(0.1) + sess.close() + t.join() + + def testNotEntered(self): + # pylint: disable=protected-access + self.assertEqual(ops._default_session_stack.get_default(), None) + # pylint: enable=protected-access + with ops.device('/cpu:0'): + sess = session.Session() + c_1 = constant_op.constant(5.0) + with sess.graph.as_default(): + c_2 = constant_op.constant(5.0) + self.assertEqual(c_1.graph, c_2.graph) + self.assertEqual(sess.run(c_2), 5.0) + with self.assertRaisesWithPredicateMatch( + ValueError, lambda e: 'No default session is registered.' in str(e)): + c_2.eval() + + def testInteractive(self): + with ops.device('/cpu:0'): + sess = session.InteractiveSession() + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval()) + d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) + e = math_ops.matmul(c, d) + self.assertAllEqual([[24.0]], e.eval()) + sess.close() + + def testSharedGraph(self): + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[2, 3]) + c = math_ops.matmul(a, b) + + with session.Session(graph=g) as sess1: + with session.Session(graph=g) as sess2: + self.assertAllEqual(sess1.run(c), sess2.run(c)) + + def testDuplicatedInputs(self): + with session.Session() as sess: + a = constant_op.constant(1.0, shape=[1, 2]) + b = constant_op.constant(2.0, shape=[1, 3]) + a_val, b_val, a2_val = sess.run([a, b, a]) + self.assertAllEqual(a_val, [[1.0, 1.0]]) + self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) + self.assertAllEqual(a2_val, [[1.0, 1.0]]) + + def testFeedAndFetch(self): + with session.Session(): + for dtype in [types.float32, + types.float64, + types.int32, + types.uint8, + types.int16, + types.int8, + types.int64, + types.bool, + types.complex64]: + for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: + np_dtype = dtype.as_numpy_dtype + + feed_t = array_ops.placeholder(dtype=dtype, shape=shape) + out_t = array_ops.identity(feed_t) + + np_array = np.random.randint(-10, 10, shape) + + if dtype == types.bool: + np_array = np_array > 0 + elif dtype == types.complex64: + np_array = np.sqrt(np_array.astype(np_dtype)) + else: + np_array = np_array.astype(np_dtype) + + self.assertAllEqual(np_array, + out_t.eval(feed_dict={feed_t: np_array})) + + def testStringFetch(self): + with session.Session(): + for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: + size = 1 + for s in shape: + size *= s + c_list = np.array([str(i) for i in xrange(size)], + dtype=np.object).reshape(shape) if size > 0 else [] + c = constant_op.constant(c_list) + self.assertAllEqual(c.eval(), c_list) + + def testStringFeed(self): + with session.Session(): + for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: + size = 1 + for s in shape: + size *= s + c_list = np.array([str(i) for i in xrange(size)], + dtype=np.object).reshape(shape) + feed_t = array_ops.placeholder(dtype=types.string, shape=shape) + c = array_ops.identity(feed_t) + self.assertAllEqual(c.eval(feed_dict={feed_t: c_list}), c_list) + + def testStringFeedWithNullCharacters(self): + with session.Session(): + c_list = ['\n\x01\x00', '\n\x00\x01'] + feed_t = array_ops.placeholder(dtype=types.string, shape=[2]) + c = array_ops.identity(feed_t) + out = c.eval(feed_dict={feed_t: c_list}) + self.assertEqual(c_list[0], out[0]) + self.assertEqual(c_list[1], out[1]) + + def testInvalidTargetFails(self): + with self.assertRaises(RuntimeError): + session.Session("INVALID_TARGET") + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/python/client/tensorflow_server.i b/tensorflow/python/client/tensorflow_server.i new file mode 100644 index 0000000000..65b3826961 --- /dev/null +++ b/tensorflow/python/client/tensorflow_server.i @@ -0,0 +1,16 @@ +%include "tensorflow/python/platform/base.i" +%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i" + +%{ +#include "tensorflow/core/public/tensorflow_server.h" +%} + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::LaunchTensorFlow; + +%include "tensorflow/core/public/tensorflow_server.h" + +%unignoreall + diff --git a/tensorflow/python/client/test_construction_fails_op.cc b/tensorflow/python/client/test_construction_fails_op.cc new file mode 100644 index 0000000000..47b2b5b49c --- /dev/null +++ b/tensorflow/python/client/test_construction_fails_op.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +REGISTER_OP("ConstructionFails"); + +class ConstructionFailsOp : public OpKernel { + public: + explicit ConstructionFailsOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES(ctx, false, + errors::InvalidArgument("Failure during construction.")); + } + + void Compute(OpKernelContext* ctx) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("ConstructionFails").Device(DEVICE_CPU), + ConstructionFailsOp); + +} // end namespace tensorflow diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i new file mode 100644 index 0000000000..30e80f779f --- /dev/null +++ b/tensorflow/python/client/tf_session.i @@ -0,0 +1,235 @@ +%include "tensorflow/python/platform/base.i" + +%{ + +#include "numpy/arrayobject.h" + +#include "tensorflow/python/client/tf_session_helper.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +%} + +// Implements the StatusNotOK exception. +%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i" + +// Required to use PyArray_* functions. +%include "tensorflow/python/platform/numpy.i" +%init %{ +import_array(); +%} + +// Release the Python GIL for the duration of most methods. +%exception { + Py_BEGIN_ALLOW_THREADS; + $action + Py_END_ALLOW_THREADS; +} + +// Proto input arguments to C API functions are passed as a (const +// void*, size_t) pair. In Python, typemap these to a single string +// argument. +%typemap(in) (const void* proto, size_t proto_len) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + $1 = static_cast<void*>(c_string); + $2 = static_cast<size_t>(py_size); +} + +//////////////////////////////////////////////////////////////////////////////// +// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() +//////////////////////////////////////////////////////////////////////////////// + +// The wrapper takes a vector of pairs of feed names and feed +// values. In Python this is represented as dictionary mapping strings +// to numpy arrays. +%typemap(in) const tensorflow::FeedVector& inputs ( + tensorflow::FeedVector temp, + tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr)), + tensorflow::Safe_PyObjectPtr temp_array_list(tensorflow::make_safe(nullptr))) { + if (!PyDict_Check($input)) { + SWIG_fail; + } + + temp_string_list = tensorflow::make_safe(PyList_New(0)); + if (!temp_string_list) { + SWIG_fail; + } + temp_array_list = tensorflow::make_safe(PyList_New(0)); + if (!temp_array_list) { + SWIG_fail; + } + + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + while (PyDict_Next($input, &pos, &key, &value)) { + const char* key_string = PyString_AsString(key); + if (!key_string) { + SWIG_fail; + } + + // The ndarray must be stored as contiguous bytes in C (row-major) order. + PyObject* array_object = PyArray_FromAny( + value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr); + if (!array_object) { + SWIG_fail; + } + PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_object); + + // Keep a reference to the key and the array, in case the incoming dict is + // modified, and/or to avoid leaking references on failure. + if (PyList_Append(temp_string_list.get(), key) == -1) { + SWIG_fail; + } + if (PyList_Append(temp_array_list.get(), array_object) == -1) { + SWIG_fail; + } + + temp.push_back(std::make_pair(key_string, array)); + } + $1 = &temp; +} + +// The wrapper also takes a list of fetch and target names. In Python this is +// represented as a list of strings. +%typemap(in) const tensorflow::NameVector& ( + tensorflow::NameVector temp, + tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr))) { + if (!PyList_Check($input)) { + SWIG_fail; + } + + Py_ssize_t len = PyList_Size($input); + + temp_string_list = tensorflow::make_safe(PyList_New(len)); + if (!temp_string_list) { + SWIG_fail; + } + + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* elem = PyList_GetItem($input, i); + if (!elem) { + SWIG_fail; + } + + // Keep a reference to the string in case the incoming list is modified. + PyList_SET_ITEM(temp_string_list.get(), i, elem); + Py_INCREF(elem); + + const char* fetch_name = PyString_AsString(elem); + if (!fetch_name) { + PyErr_SetString(PyExc_TypeError, + "a fetch or target name was not a string"); + SWIG_fail; + } + + // TODO(mrry): Avoid copying the fetch name in, if this impacts performance. + temp.push_back(fetch_name); + } + $1 = &temp; +} + + +// The wrapper has two outputs: a tensorflow::Status, and a vector of +// PyObjects containing the fetch results (iff the status is OK). Since +// the interpretation of the vector depends on the status, we define +// them as two consecutive out arguments, so that they can be accessed +// together in a typemap. + +// Define temporaries for the argout outputs. +%typemap(in, numinputs=0) tensorflow::Status* out_status ( + tensorflow::Status temp) { + $1 = &temp; +} +%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values ( + tensorflow::PyObjectVector temp) { + $1 = &temp; +} + +// Raise a StatusNotOK exception if the out_status is not OK; +// otherwise build a Python list of outputs and return it. +%typemap(argout, fragment="StatusNotOK") ( + tensorflow::Status* out_status, tensorflow::PyObjectVector* out_values) { + if (!$1->ok()) { + RaiseStatusNotOK(*$1, $descriptor(tensorflow::Status*)); + SWIG_fail; + } else { + tensorflow::Safe_PyObjectVector out_values_safe; + for (int i = 0; i < $2->size(); ++i) { + out_values_safe.emplace_back(tensorflow::make_safe($2->at(i))); + } + + $result = PyList_New($2->size()); + if (!$result) { + SWIG_fail; + } + + for (int i = 0; i < $2->size(); ++i) { + PyList_SET_ITEM($result, i, $2->at(i)); + out_values_safe[i].release(); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// END TYPEMAPS FOR tensorflow::TF_Run_wrapper() +//////////////////////////////////////////////////////////////////////////////// + + + +// Include the functions from tensor_c_api.h, except TF_Run. +%ignoreall +%unignore TF_Code; +%unignore TF_Status; +%unignore TF_NewStatus; +%unignore TF_DeleteStatus; +%unignore TF_GetCode; +%unignore TF_Message; +%unignore TF_SessionOptions; +%rename("_TF_SetTarget") TF_SetTarget; +%rename("_TF_SetConfig") TF_SetConfig; +%rename("_TF_NewSessionOptions") TF_NewSessionOptions; +%unignore TF_DeleteSessionOptions; +%unignore TF_NewSession; +%unignore TF_CloseSession; +%unignore TF_DeleteSession; +%unignore TF_ExtendGraph; +%include "tensorflow/core/public/tensor_c_api.h" +%ignoreall + +%insert("python") %{ + def TF_NewSessionOptions(target=None, config=None): + opts = _TF_NewSessionOptions() + if target is not None: + _TF_SetTarget(opts, target) + if config is not None: + from tensorflow.core.framework import config_pb2 + if not isinstance(config, config_pb2.ConfigProto): + raise TypeError("Expected config_pb2.ConfigProto, " + "but got %s" % type(config)) + status = TF_NewStatus() + config_str = config.SerializeToString() + _TF_SetConfig(opts, config_str, len(config_str), status) + if TF_GetCode(status) != 0: + raise ValueError(TF_Message(status)) + return opts +%} + +// Include the wrapper for TF_Run from tf_session_helper.h. + +// The %exception block above releases the Python GIL for the length +// of each wrapped method. We disable this behavior for TF_Run +// because it uses the Python allocator. +%noexception tensorflow::TF_Run_wrapper; +%rename(TF_Run) tensorflow::TF_Run_wrapper; +%unignore tensorflow; +%unignore TF_Run; + +%include "tensorflow/python/client/tf_session_helper.h" + +%unignoreall diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc new file mode 100644 index 0000000000..06483da87b --- /dev/null +++ b/tensorflow/python/client/tf_session_helper.cc @@ -0,0 +1,518 @@ +#include "tensorflow/python/client/tf_session_helper.h" + +#include <cstring> + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +namespace { + +// Container types for the various temporary values used internally in +// the wrapper. + +// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. +typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector; + +// Safe containers for (an) owned TF_Tensor(s). On destruction, the +// tensor will be deleted by TF_DeleteTensor. +typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> + Safe_TF_TensorPtr; +typedef std::vector<Safe_TF_TensorPtr> Safe_TF_TensorVector; +Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) { + return Safe_TF_TensorPtr(tensor, TF_DeleteTensor); +} + +// Safe container for an owned TF_Status. On destruction, the status +// will be deleted by TF_DeleteStatus. +typedef std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> + Safe_TF_StatusPtr; +Safe_TF_StatusPtr make_safe(TF_Status* status) { + return Safe_TF_StatusPtr(status, TF_DeleteStatus); +} + +Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, + TF_DataType* out_tf_datatype) { + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + if (PyDict_Next(descr->fields, &pos, &key, &value)) { + const char* key_string = PyString_AsString(key); + if (!key_string) { + return errors::Internal("Corrupt numpy type descriptor"); + } + tensorflow::string key = key_string; + // The typenames here should match the field names in the custom struct + // types constructed in test_util.py. + // TODO(mrry,keveman): Investigate Numpy type registration to replace this + // hard-coding of names. + if (key == "quint8") { + *out_tf_datatype = TF_QUINT8; + } else if (key == "qint8") { + *out_tf_datatype = TF_QINT8; + } else if (key == "qint32") { + *out_tf_datatype = TF_QINT32; + } else { + return errors::Internal("Unsupported numpy data type"); + } + return Status::OK(); + } + return errors::Internal("Unsupported numpy data type"); +} + +Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array, + TF_DataType* out_tf_datatype) { + int pyarray_type = PyArray_TYPE(array); + PyArray_Descr* descr = array->descr; + switch (pyarray_type) { + case NPY_FLOAT32: + *out_tf_datatype = TF_FLOAT; + break; + case NPY_FLOAT64: + *out_tf_datatype = TF_DOUBLE; + break; + case NPY_INT32: + *out_tf_datatype = TF_INT32; + break; + case NPY_UINT8: + *out_tf_datatype = TF_UINT8; + break; + case NPY_INT16: + *out_tf_datatype = TF_INT16; + break; + case NPY_INT8: + *out_tf_datatype = TF_INT8; + break; + case NPY_INT64: + *out_tf_datatype = TF_INT64; + break; + case NPY_BOOL: + *out_tf_datatype = TF_BOOL; + break; + case NPY_COMPLEX64: + *out_tf_datatype = TF_COMPLEX; + break; + case NPY_OBJECT: + *out_tf_datatype = TF_STRING; + break; + case NPY_VOID: + // Quantized types are currently represented as custom struct types. + // PyArray_TYPE returns NPY_VOID for structs, and we should look into + // descr to derive the actual type. + return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype); + default: + // TODO(mrry): Support these. + return errors::Internal("Unsupported feed type"); + } + return Status::OK(); +} + +Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype, + int* out_pyarray_type) { + switch (tf_datatype) { + case TF_FLOAT: + *out_pyarray_type = NPY_FLOAT32; + break; + case TF_DOUBLE: + *out_pyarray_type = NPY_FLOAT64; + break; + case TF_INT32: + *out_pyarray_type = NPY_INT32; + break; + case TF_UINT8: + *out_pyarray_type = NPY_UINT8; + break; + case TF_INT16: + *out_pyarray_type = NPY_INT16; + break; + case TF_INT8: + *out_pyarray_type = NPY_INT8; + break; + case TF_INT64: + *out_pyarray_type = NPY_INT64; + break; + case TF_BOOL: + *out_pyarray_type = NPY_BOOL; + break; + case TF_COMPLEX: + *out_pyarray_type = NPY_COMPLEX64; + break; + case TF_STRING: + *out_pyarray_type = NPY_OBJECT; + break; + // TODO(keveman): These should be changed to NPY_VOID, and the type used for + // the resulting numpy array should be the custom struct types that we + // expect for quantized types. + case TF_QINT8: + *out_pyarray_type = NPY_INT8; + break; + case TF_QUINT8: + *out_pyarray_type = NPY_UINT8; + break; + case TF_QINT32: + *out_pyarray_type = NPY_INT32; + break; + case TF_BFLOAT16: + *out_pyarray_type = NPY_UINT16; + break; + default: + return errors::Internal("Unsupported fetch type"); + } + return Status::OK(); +} + +// Iterate over the string array 'array', extract the ptr and len of each string +// element and call f(ptr, len). +template <typename F> +Status PyStringArrayMap(PyArrayObject* array, F f) { + Safe_PyObjectPtr iter = tensorflow::make_safe( + PyArray_IterNew(reinterpret_cast<PyObject*>(array))); + while (PyArray_ITER_NOTDONE(iter.get())) { + auto item = tensorflow::make_safe( + PyArray_GETITEM(array, PyArray_ITER_DATA(iter.get()))); + if (!item.get()) { + return errors::Internal("Unable to get element from the feed."); + } + char* ptr; + Py_ssize_t len; + int success = PyString_AsStringAndSize(item.get(), &ptr, &len); + if (success != 0) { + return errors::Internal("Unable to get element from the feed."); + } + f(ptr, len); + PyArray_ITER_NEXT(iter.get()); + } + return Status::OK(); +} + +// Encode the strings in 'array' into a contiguous buffer and return the base of +// the buffer. The caller takes ownership of the buffer. +Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems, + size_t* size, void** buffer) { + // Compute bytes needed for encoding. + *size = 0; + TF_RETURN_IF_ERROR( + PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) { + *size += sizeof(tensorflow::uint64) + + tensorflow::core::VarintLength(len) + len; + })); + // Encode all strings. + std::unique_ptr<char[]> base_ptr(new char[*size]); + char* base = base_ptr.get(); + char* data_start = base + sizeof(tensorflow::uint64) * nelems; + char* dst = data_start; // Where next string is encoded. + tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base); + + TF_RETURN_IF_ERROR(PyStringArrayMap( + array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) { + *offsets = (dst - data_start); + offsets++; + dst = tensorflow::core::EncodeVarint64(dst, len); + memcpy(dst, ptr, len); + dst += len; + })); + CHECK_EQ(dst, base + *size); + *buffer = base_ptr.release(); + return Status::OK(); +} + +// Determine the pointer and offset of the string at offset 'i' in the string +// tensor 'src', whose total length is 'num_elements'. +static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src, + tensorflow::int64 num_elements, + tensorflow::int64 i, + const char** ptr, + tensorflow::uint64* len) { + const char* input = reinterpret_cast<const char*>(TF_TensorData(src)); + const size_t src_size = TF_TensorByteSize(src); + const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; + const char* limit = input + src_size; + tensorflow::uint64 offset = + reinterpret_cast<const tensorflow::uint64*>(input)[i]; + const char* p = + tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len); + if (offset >= (limit - data_start) || !p || (*len > (limit - p))) { + return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i, + " out of range"); + } + *ptr = p; + return Status::OK(); +} + +// Copy the string at offset 'i' in the (linearized) string tensor 'tensor' into +// 'pyarray' at offset pointed by the 'i_ptr' iterator. +static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr, + TF_Tensor* tensor, + tensorflow::int64 num_elements, + tensorflow::int64 i) { + const char* ptr; + tensorflow::uint64 len; + TF_RETURN_IF_ERROR( + TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len)); + auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len)); + int success = + PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get()); + if (success != 0) { + return errors::Internal("Error setting element ", i); + } + return Status::OK(); +} + +// Converts the given TF_Tensor to a Numpy array. +// If the returned status is OK, the caller becomes the owner of *out_array. +Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) { + // A fetched operation will correspond to a null tensor, and a None + // in Python. + if (tensor == nullptr) { + Py_INCREF(Py_None); + *out_array = Py_None; + return Status::OK(); + } + + const int ndims = TF_NumDims(tensor); + gtl::InlinedVector<npy_intp, 4> dims(ndims); + tensorflow::int64 nelems = 1; + for (int i = 0; i < ndims; ++i) { + dims[i] = TF_Dim(tensor, i); + nelems *= dims[i]; + } + + // Convert TensorFlow dtype to numpy type descriptor. + int type_num; + TF_RETURN_IF_ERROR( + TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num)); + PyArray_Descr* descr = PyArray_DescrFromType(type_num); + + // Copy the TF_TensorData into a newly-created ndarray and return it. + // TODO(mrry): Perhaps investigate zero-copy approaches. This would involve + // creating an ndarray-like object that wraps the TF_Tensor buffer, and + // maps its destructor to TF_DeleteTensor. + Safe_PyObjectPtr safe_out_array = + tensorflow::make_safe(PyArray_Empty(ndims, dims.data(), descr, 0)); + if (!safe_out_array) { + return errors::Internal("Could not allocate ndarray"); + } + PyArrayObject* py_array = + reinterpret_cast<PyArrayObject*>(safe_out_array.get()); + if (PyArray_NBYTES(py_array) != TF_TensorByteSize(tensor)) { + if (TF_TensorType(tensor) == TF_STRING) { + // Copy element by element. + auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get())); + for (tensorflow::int64 i = 0; i < nelems; ++i) { + auto s = + CopyStringToPyArrayElement(py_array, iter.get(), tensor, nelems, i); + if (!s.ok()) { + return s; + } + PyArray_ITER_NEXT(iter.get()); + } + } else { + return errors::Internal("ndarray was ", PyArray_NBYTES(py_array), + " bytes but TF_Tensor was ", + TF_TensorByteSize(tensor), " bytes"); + } + } else { + memcpy(py_array->data, TF_TensorData(tensor), PyArray_NBYTES(py_array)); + } + + // PyArray_Return turns rank 0 arrays into numpy scalars + *out_array = PyArray_Return( + reinterpret_cast<PyArrayObject*>(safe_out_array.release())); + return Status::OK(); +} + +tensorflow::Status TF_Status_to_Status(TF_Status* tf_status) { + TF_Code code = TF_GetCode(tf_status); + const string message(TF_Message(tf_status)); + + switch (code) { + case TF_OK: + return Status::OK(); + case TF_CANCELLED: + return errors::Cancelled(message); + case TF_UNKNOWN: + return errors::Unknown(message); + case TF_INVALID_ARGUMENT: + return errors::InvalidArgument(message); + case TF_DEADLINE_EXCEEDED: + return errors::DeadlineExceeded(message); + case TF_NOT_FOUND: + return errors::NotFound(message); + case TF_ALREADY_EXISTS: + return errors::AlreadyExists(message); + case TF_PERMISSION_DENIED: + return errors::PermissionDenied(message); + case TF_UNAUTHENTICATED: + return errors::Unauthenticated(message); + case TF_RESOURCE_EXHAUSTED: + return errors::ResourceExhausted(message); + case TF_FAILED_PRECONDITION: + return errors::FailedPrecondition(message); + case TF_ABORTED: + return errors::Aborted(message); + case TF_OUT_OF_RANGE: + return errors::OutOfRange(message); + case TF_UNIMPLEMENTED: + return errors::Unimplemented(message); + case TF_INTERNAL: + return errors::Internal(message); + case TF_UNAVAILABLE: + return errors::Unavailable(message); + case TF_DATA_LOSS: + return errors::DataLoss(message); + default: + return errors::Internal("Got error with unknown code: ", code, " ", + message); + } +} + +static bool numpy_imported = false; + +} // namespace + +Safe_PyObjectPtr make_safe(PyObject* o) { + return Safe_PyObjectPtr(o, Py_DECREF_wrapper); +} + +// Wrapper for TF_Run that converts the arguments to appropriate types. +// If *out_status is OK, the caller becomes the owner of the PyObjects +// in *out_values. +void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + PyObjectVector* out_values) { + // 0. Ensure that numpy has been imported. + if (!numpy_imported) { + import_array(); + numpy_imported = true; + } + + // 1. Convert the feed inputs to the appropriate form for TF_Run. + NameVector input_names; + Safe_PyObjectVector + py_inputs_safe; // Used to decref the input arrays on failure. + Safe_TF_TensorVector inputs_safe; // Used to delete tensors on failure. + TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run. + + for (const auto& name_and_array : inputs) { + py_inputs_safe.emplace_back( + make_safe(reinterpret_cast<PyObject*>(name_and_array.second))); + } + + for (int i = 0; i < inputs.size(); ++i) { + input_names.push_back(inputs[i].first); + PyArrayObject* array = inputs[i].second; + + // Convert numpy dtype to TensorFlow dtype. + TF_DataType dtype; + *out_status = PyArray_TYPE_to_TF_DataType(array, &dtype); + if (!out_status->ok()) { + return; + } + + tensorflow::int64 nelems = 1; + gtl::InlinedVector<tensorflow::int64, 4> dims; + for (int i = 0; i < PyArray_NDIM(array); ++i) { + dims.push_back(PyArray_SHAPE(array)[i]); + nelems *= dims[i]; + } + + // Create a TF_Tensor based on the fed data. In the case of non-string data + // type, this steals a reference to array, which will be relinquished when + // the underlying buffer is deallocated. For string, a new temporary buffer + // is allocated into which the strings are encoded. + if (dtype != TF_STRING) { + // NOTE(mrry): We currently copy the numpy array into a new + // buffer to avoid possible issues on deallocation (such as + // having to acquire the Python Global Interpreter Lock). + // TODO(mrry): Investigate in what cases we can safely acquire + size_t size = PyArray_NBYTES(array); + // NOTE(mrry): 32 is the upper bound on current alignment + // requirements for tensorflow::Tensor. We hard code this here to + // avoid taking a dependency on Eigen in the client code. + void* data = tensorflow::cpu_allocator()->AllocateRaw(32, size); + std::memcpy(data, array->data, size); + inputs_safe.emplace_back(make_safe( + TF_NewTensor(dtype, dims.data(), dims.size(), data, size, + [](void* data, size_t len, void* arg) { + tensorflow::cpu_allocator()->DeallocateRaw(data); + }, + nullptr))); + // The destruction of the numpy array will now be handled by the + // inputs_safe destructor. + py_inputs_safe[i].reset(); + } else { + size_t size; + void* encoded; + Status s = EncodePyStringArray(array, nelems, &size, &encoded); + if (!s.ok()) { + *out_status = s; + return; + } + inputs_safe.emplace_back( + make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size, + [](void* data, size_t len, void* arg) { + delete[] reinterpret_cast<char*>(data); + }, + array))); + // The destruction of the numpy array will now be handled by the + // inputs_safe destructor. + py_inputs_safe[i].reset(); + } + inputs_unsafe.push_back(inputs_safe.back().get()); + } + + // 2. Allocate a container for the output data. + TF_TensorVector outputs(output_names.size()); + + Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); + + // 3. Actually call TF_Run(). + Py_BEGIN_ALLOW_THREADS; + TF_Run(session, input_names.data(), inputs_unsafe.data(), input_names.size(), + const_cast<const char**>(output_names.data()), outputs.data(), + output_names.size(), const_cast<const char**>(target_nodes.data()), + target_nodes.size(), status.get()); + Py_END_ALLOW_THREADS; + + // 4. The TensorFlow runtime has taken ownership of the fed tensors, + // so we release the safe pointers to them. + for (auto& input : inputs_safe) { + input.release(); + } + + if (TF_GetCode(status.get()) != TF_OK) { + *out_status = TF_Status_to_Status(status.get()); + return; + } + + // 5. We now own the fetched tensors, so set up a safe container to + // delete them when we exit this scope. + Safe_TF_TensorVector tf_outputs_safe; + for (const auto& output : outputs) { + tf_outputs_safe.emplace_back(make_safe(output)); + } + + // 6. Convert the fetched tensors into numpy ndarrays. Store them in a safe + // container so that we do not leak + Safe_PyObjectVector py_outputs_safe; + for (int i = 0; i < output_names.size(); ++i) { + PyObject* py_array; + *out_status = TF_Tensor_to_PyObject(outputs[i], &py_array); + if (!out_status->ok()) { + return; + } + py_outputs_safe.emplace_back(make_safe(py_array)); + } + + // 7. If we reach this point, we have successfully built a list of objects + // so we can release them from the safe container. + for (auto& output : py_outputs_safe) { + out_values->push_back(output.release()); + } + *out_status = Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h new file mode 100644 index 0000000000..12a7527ed9 --- /dev/null +++ b/tensorflow/python/client/tf_session_helper.h @@ -0,0 +1,56 @@ +#ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ +#define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ + +#include <Python.h> + +#include "numpy/arrayobject.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor_c_api.h" + +namespace tensorflow { + +// Container types for the various arguments and temporary values used +// in the wrapper. + +// A FeedVector is a vector of tensor name and numpy array pairs. The +// name is a borrowed C string. +typedef tensorflow::gtl::InlinedVector<std::pair<const char*, PyArrayObject*>, + 8> FeedVector; + +// A NameVector is a vector of tensor or operation names, as borrowed +// C strings. +typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector; + +// A PyObjectVector is a vector of borrowed pointers to PyObjects. +typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector; + +// Safe containers for (an) owned PyObject(s). On destruction, the +// reference count of the contained object will be decremented. +inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); } +typedef void (*Py_DECREF_wrapper_type)(PyObject*); +typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr; +typedef std::vector<Safe_PyObjectPtr> Safe_PyObjectVector; +Safe_PyObjectPtr make_safe(PyObject* o); + +// Run the graph associated with the session starting with the +// supplied inputs[]. Regardless of success of failure, inputs[] are +// stolen by the implementation (i.e. the implementation will +// eventually call Py_DECREF on each array input). +// +// On success, the tensors corresponding to output_names[0,noutputs-1] +// are placed in out_values[], and these outputs[] become the property +// of the caller (the caller must eventually call Py_DECREF on them). +// +// On failure, out_status contains a tensorflow::Status with an error +// message. +void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, + const NameVector& output_names, + const NameVector& target_nodes, Status* out_status, + PyObjectVector* out_values); + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ |