aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client')
-rwxr-xr-xtensorflow/python/client/__init__.py0
-rw-r--r--tensorflow/python/client/client_lib.py40
-rw-r--r--tensorflow/python/client/events_writer.i34
-rw-r--r--tensorflow/python/client/events_writer_test.py54
-rw-r--r--tensorflow/python/client/graph_util.py138
-rw-r--r--tensorflow/python/client/graph_util_test.py126
-rw-r--r--tensorflow/python/client/notebook.py104
-rw-r--r--tensorflow/python/client/session.py567
-rw-r--r--tensorflow/python/client/session_test.py555
-rw-r--r--tensorflow/python/client/tensorflow_server.i16
-rw-r--r--tensorflow/python/client/test_construction_fails_op.cc22
-rw-r--r--tensorflow/python/client/tf_session.i235
-rw-r--r--tensorflow/python/client/tf_session_helper.cc518
-rw-r--r--tensorflow/python/client/tf_session_helper.h56
14 files changed, 2465 insertions, 0 deletions
diff --git a/tensorflow/python/client/__init__.py b/tensorflow/python/client/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/client/__init__.py
diff --git a/tensorflow/python/client/client_lib.py b/tensorflow/python/client/client_lib.py
new file mode 100644
index 0000000000..9148ed17c0
--- /dev/null
+++ b/tensorflow/python/client/client_lib.py
@@ -0,0 +1,40 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+"""This library contains classes for launching graphs and executing operations.
+
+The [basic usage](../../get_started/index.md#basic-usage) guide has
+examples of how a graph is launched in a [`tf.Session`](#Session).
+
+## Session management
+
+@@Session
+
+@@get_default_session
+
+## Error classes
+
+@@OpError
+@@CancelledError
+@@UnknownError
+@@InvalidArgumentError
+@@DeadlineExceededError
+@@NotFoundError
+@@AlreadyExistsError
+@@PermissionDeniedError
+@@UnauthenticatedError
+@@ResourceExhaustedError
+@@FailedPreconditionError
+@@AbortedError
+@@OutOfRangeError
+@@UnimplementedError
+@@InternalError
+@@UnavailableError
+@@DataLossError
+"""
+
+from tensorflow.python.client.session import InteractiveSession
+from tensorflow.python.client.session import Session
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework.errors import OpError
+
+from tensorflow.python.framework.ops import get_default_session
diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i
new file mode 100644
index 0000000000..cbf42e2791
--- /dev/null
+++ b/tensorflow/python/client/events_writer.i
@@ -0,0 +1,34 @@
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/util/events_writer.h"
+#include "tensorflow/core/util/event.pb.h"
+%}
+
+%nodefaultctor EventsWriter;
+
+%ignoreall
+%unignore tensorflow;
+%unignore tensorflow::EventsWriter;
+%unignore tensorflow::EventsWriter::EventsWriter;
+%unignore tensorflow::EventsWriter::~EventsWriter;
+%unignore tensorflow::EventsWriter::FileName;
+%rename("_WriteSerializedEvent") tensorflow::EventsWriter::WriteSerializedEvent;
+%unignore tensorflow::EventsWriter::Flush;
+%unignore tensorflow::EventsWriter::Close;
+%include "tensorflow/core/util/events_writer.h"
+%unignoreall
+
+%newobject tensorflow::EventsWriter::EventsWriter;
+
+
+%extend tensorflow::EventsWriter {
+%insert("python") %{
+ def WriteEvent(self, event):
+ from tensorflow.core.util.event_pb2 import Event
+ if not isinstance(event, Event):
+ raise TypeError("Expected an event_pb2.Event proto, "
+ " but got %s" % type(event))
+ return self._WriteSerializedEvent(event.SerializeToString())
+%}
+}
diff --git a/tensorflow/python/client/events_writer_test.py b/tensorflow/python/client/events_writer_test.py
new file mode 100644
index 0000000000..60bce49b1f
--- /dev/null
+++ b/tensorflow/python/client/events_writer_test.py
@@ -0,0 +1,54 @@
+"""Tests for the SWIG-wrapped events writer."""
+import os.path
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class PywrapeventsWriterTest(test_util.TensorFlowTestCase):
+
+ def testWriteEvents(self):
+ file_prefix = os.path.join(self.get_temp_dir(), "events")
+ writer = pywrap_tensorflow.EventsWriter(file_prefix)
+ filename = writer.FileName()
+ event_written = event_pb2.Event(
+ wall_time=123.45, step=67,
+ summary=summary_pb2.Summary(
+ value=[summary_pb2.Summary.Value(tag="foo", simple_value=89.0)]))
+ writer.WriteEvent(event_written)
+ writer.Flush()
+ writer.Close()
+
+ with self.assertRaises(IOError):
+ for r in tf_record.tf_record_iterator(filename + "DOES_NOT_EXIST"):
+ self.assertTrue(False)
+
+ reader = tf_record.tf_record_iterator(filename)
+ event_read = event_pb2.Event()
+
+ event_read.ParseFromString(next(reader))
+ self.assertTrue(event_read.HasField("file_version"))
+
+ event_read.ParseFromString(next(reader))
+ # Second event
+ self.assertProtoEquals("""
+ wall_time: 123.45 step: 67
+ summary { value { tag: 'foo' simple_value: 89.0 } }
+ """, event_read)
+
+ with self.assertRaises(StopIteration):
+ next(reader)
+
+ def testWriteEventInvalidType(self):
+ class _Invalid(object):
+ def __str__(self): return "Invalid"
+ with self.assertRaisesRegexp(TypeError, "Invalid"):
+ pywrap_tensorflow.EventsWriter("foo").WriteEvent(_Invalid())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py
new file mode 100644
index 0000000000..4c65a445ae
--- /dev/null
+++ b/tensorflow/python/client/graph_util.py
@@ -0,0 +1,138 @@
+"""Helpers to manipulate a tensor graph in python.
+"""
+
+import tensorflow.python.platform
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.platform import logging
+
+_VARIABLE_OPS = {
+ "Assign",
+ "AssignAdd",
+ "AssignSub",
+ "Queue",
+ "RandomParameters",
+ "ScatterAdd",
+ "ScatterSub",
+ "ScatterUpdate",
+ "Variable",
+}
+
+
+def _is_variable_op(op):
+ """Returns true if 'op' refers to a Variable node."""
+ return op in _VARIABLE_OPS
+
+
+def set_cpu0(device_string):
+ """Creates a new device string based on `device_string' but using /CPU:0.
+
+ If the device is already on /CPU:0, this is a no-op.
+
+ Args:
+ device_string: A device string.
+
+ Returns:
+ A device string.
+ """
+ parsed_device = pydev.from_string(device_string)
+ parsed_device.device_type = "CPU"
+ parsed_device.device_index = 0
+ return parsed_device.to_string()
+
+
+def must_run_on_cpu(node, pin_variables_on_cpu=False):
+ """Returns True if the given node_def must run on CPU, otherwise False.
+
+ Args:
+ node: The node to be assigned to a device. Could be either an ops.Operation
+ or NodeDef.
+ pin_variables_on_cpu: If True, this function will return False if node_def
+ represents a variable-related op.
+
+ Returns:
+ True if the given node must run on CPU, otherwise False.
+ """
+
+ if isinstance(node, ops.Operation):
+ node_def = node.node_def
+ else:
+ assert isinstance(node, graph_pb2.NodeDef)
+ node_def = node
+
+ # If the op is a variable-related op, should we pin it on CPU?
+ if pin_variables_on_cpu and _is_variable_op(node_def.op):
+ return True
+
+ # Constant operations producing a string or int32 must run on CPU.
+ if node_def.op == "Const":
+ # Get the value of the 'dtype' attr
+ dtype = node_def.attr["dtype"].type
+ if dtype == types.string or dtype == types.int32:
+ return True
+
+ if node_def.op == "DynamicStitch":
+ dtype = node_def.attr["T"].type
+ if dtype == types.int32:
+ # DynamicStitch on GPU only works for int32 values.
+ return True
+
+ if node_def.op in ["Cast"]:
+ dtype = node_def.attr["SrcT"].type
+ if dtype == types.int32:
+ # Cast on GPU does not works for int32 values.
+ return True
+ return False
+
+
+################################################################################
+#
+# device functions for use in with g.device(...)
+#
+################################################################################
+
+
+def pin_variables_on_cpu(op):
+ """Returns a CPU device for Variable nodes if the device is not specified.
+
+ Args:
+ op: The ops.Operation object describing the node for which a device
+ should be chosen. The op.device field is respected.
+
+ Returns:
+ A device containing "/device:CPU:0" if the node is related to a variable.
+ """
+ device = op.device if op.device is not None else ""
+ dev = pydev.from_string(device)
+
+ # If a device type exists already, do not override.
+ if dev.device_type:
+ return device
+
+ if isinstance(op, ops.Operation):
+ node_def = op.node_def
+ else:
+ assert isinstance(op, graph_pb2.NodeDef)
+ node_def = op
+
+ if _is_variable_op(node_def.op):
+ return set_cpu0(device)
+ return device
+
+
+def pin_to_cpu(op):
+ """Returns a CPU device for the given node."""
+ device = op.device if op.device is not None else ""
+ dev = pydev.from_string(device)
+
+ if not dev.device_type:
+ return set_cpu0(device)
+ if dev.device_type == "CPU":
+ return device
+
+ logging.info("Operation %s has been assigned to a non-CPU (%s), so "
+ "it will not be pinned to the CPU.", op.name, dev.device_type)
+ return device
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
new file mode 100644
index 0000000000..8066f722a8
--- /dev/null
+++ b/tensorflow/python/client/graph_util_test.py
@@ -0,0 +1,126 @@
+"""Tests for tensorflow.python.client.graph_util."""
+import tensorflow.python.platform
+
+from tensorflow.python.client import graph_util
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+# pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+# pylint: enable=unused-import
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import googletest
+
+
+class DeviceFunctionsTest(googletest.TestCase):
+
+ def testPinToCpu(self):
+ with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
+ const_a = constant_op.constant(5.0)
+ const_b = constant_op.constant(10.0)
+ add_c = const_a + const_b
+ var_v = state_ops.variable_op([], dtype=types.float32)
+ assign_c_to_v = state_ops.assign(var_v, add_c)
+ const_string = constant_op.constant("on a cpu")
+ dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
+ dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
+ self.assertEqual(const_a.device, "/device:CPU:0")
+ self.assertEqual(const_b.device, "/device:CPU:0")
+ self.assertEqual(add_c.device, "/device:CPU:0")
+ self.assertEqual(var_v.device, "/device:CPU:0")
+ self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
+ self.assertEqual(const_string.device, "/device:CPU:0")
+ self.assertEqual(dynamic_stitch_int_result.device, "/device:CPU:0")
+ self.assertEqual(dynamic_stitch_float_result.device, "/device:CPU:0")
+
+ def testPinRequiredOpsOnCPU(self):
+ with ops.Graph().as_default() as g, g.device(
+ graph_util.pin_variables_on_cpu):
+ const_a = constant_op.constant(5.0)
+ const_b = constant_op.constant(10.0)
+ add_c = const_a + const_b
+ var_v = state_ops.variable_op([], dtype=types.float32)
+ assign_c_to_v = state_ops.assign(var_v, add_c)
+ dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]])
+ dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
+ [[0, 1, 2], [2, 3]], [[12.0, 23.0, 34.0], [1.0, 2.0]])
+ # Non-variable ops shuld not specify a device
+ self.assertEqual(const_a.device, None)
+ self.assertEqual(const_b.device, None)
+ self.assertEqual(add_c.device, None)
+ # Variable ops specify a device
+ self.assertEqual(var_v.device, "/device:CPU:0")
+ self.assertEqual(assign_c_to_v.device, "/device:CPU:0")
+
+ def testTwoDeviceFunctions(self):
+ with ops.Graph().as_default() as g:
+ var_0 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device(graph_util.pin_variables_on_cpu):
+ var_1 = state_ops.variable_op([1], dtype=types.float32)
+ var_2 = state_ops.variable_op([1], dtype=types.float32)
+ var_3 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device(graph_util.pin_variables_on_cpu):
+ var_4 = state_ops.variable_op([1], dtype=types.float32)
+ with g.device("/device:GPU:0"):
+ var_5 = state_ops.variable_op([1], dtype=types.float32)
+ var_6 = state_ops.variable_op([1], dtype=types.float32)
+
+ self.assertEqual(var_0.device, None)
+ self.assertEqual(var_1.device, "/device:CPU:0")
+ self.assertEqual(var_2.device, None)
+ self.assertEqual(var_3.device, None)
+ self.assertEqual(var_4.device, "/device:CPU:0")
+ self.assertEqual(var_5.device, "/device:GPU:0")
+ self.assertEqual(var_6.device, "/device:CPU:0")
+
+ def testExplicitDevice(self):
+ with ops.Graph().as_default() as g:
+ const_0 = constant_op.constant(5.0)
+ with g.device("/device:GPU:0"):
+ const_1 = constant_op.constant(5.0)
+ with g.device("/device:GPU:1"):
+ const_2 = constant_op.constant(5.0)
+ with g.device("/device:CPU:0"):
+ const_3 = constant_op.constant(5.0)
+ with g.device("/device:CPU:1"):
+ const_4 = constant_op.constant(5.0)
+ with g.device("/job:ps"):
+ const_5 = constant_op.constant(5.0)
+
+ self.assertEqual(const_0.device, None)
+ self.assertEqual(const_1.device, "/device:GPU:0")
+ self.assertEqual(const_2.device, "/device:GPU:1")
+ self.assertEqual(const_3.device, "/device:CPU:0")
+ self.assertEqual(const_4.device, "/device:CPU:1")
+ self.assertEqual(const_5.device, "/job:ps")
+
+ def testDefaultDevice(self):
+ with ops.Graph().as_default() as g, g.device(
+ graph_util.pin_variables_on_cpu):
+ with g.device("/job:ps"):
+ const_0 = constant_op.constant(5.0)
+ with g.device("/device:GPU:0"):
+ const_1 = constant_op.constant(5.0)
+ with g.device("/device:GPU:1"):
+ const_2 = constant_op.constant(5.0)
+ with g.device("/device:CPU:0"):
+ const_3 = constant_op.constant(5.0)
+ with g.device("/device:CPU:1"):
+ const_4 = constant_op.constant(5.0)
+ with g.device("/replica:0"):
+ const_5 = constant_op.constant(5.0)
+
+ self.assertEqual(const_0.device, "/job:ps")
+ self.assertEqual(const_1.device, "/device:GPU:0")
+ self.assertEqual(const_2.device, "/device:GPU:1")
+ self.assertEqual(const_3.device, "/device:CPU:0")
+ self.assertEqual(const_4.device, "/device:CPU:1")
+ self.assertEqual(const_5.device, "/replica:0")
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py
new file mode 100644
index 0000000000..1871fbc632
--- /dev/null
+++ b/tensorflow/python/client/notebook.py
@@ -0,0 +1,104 @@
+"""Notebook front-end to TensorFlow.
+
+When you run this binary, you'll see something like below, which indicates
+the serving URL of the notebook:
+
+ The IPython Notebook is running at: http://127.0.0.1:8888/
+
+Press "Shift+Enter" to execute a cell
+Press "Enter" on a cell to go into edit mode.
+Press "Escape" to go back into command mode and use arrow keys to navigate.
+Press "a" in command mode to insert cell above or "b" to insert cell below.
+
+Your root notebooks directory is FLAGS.notebook_dir
+"""
+
+
+import os
+import socket
+import sys
+
+# pylint: disable=g-import-not-at-top
+# Official recommended way of turning on fast protocol buffers as of 10/21/14
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
+
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string(
+ "password", None,
+ "Password to require. If set, the server will allow public access."
+ " Only used if notebook config file does not exist.")
+
+flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
+ "root location where to store notebooks")
+
+ORIG_ARGV = sys.argv
+# Main notebook process calls itself with argv[1]="kernel" to start kernel
+# subprocesses.
+IS_KERNEL = len(sys.argv) > 1 and sys.argv[1] == "kernel"
+
+
+def main(unused_argv):
+ sys.argv = ORIG_ARGV
+
+ if not IS_KERNEL:
+ # Drop all flags.
+ sys.argv = [sys.argv[0]]
+ # NOTE(sadovsky): For some reason, putting this import at the top level
+ # breaks inline plotting. It's probably a bug in the stone-age version of
+ # matplotlib.
+ from IPython.html.notebookapp import NotebookApp # pylint: disable=g-import-not-at-top
+ notebookapp = NotebookApp.instance()
+ notebookapp.open_browser = True
+
+ # password functionality adopted from quality/ranklab/main/tools/notebook.py
+ # add options to run with "password"
+ if FLAGS.password:
+ from IPython.lib import passwd # pylint: disable=g-import-not-at-top
+ notebookapp.ip = "0.0.0.0"
+ notebookapp.password = passwd(FLAGS.password)
+ else:
+ print ("\nNo password specified; Notebook server will only be available"
+ " on the local machine.\n")
+ notebookapp.initialize(argv=["--notebook-dir", FLAGS.notebook_dir])
+
+ if notebookapp.ip == "0.0.0.0":
+ proto = "https" if notebookapp.certfile else "http"
+ url = "%s://%s:%d%s" % (proto, socket.gethostname(), notebookapp.port,
+ notebookapp.base_project_url)
+ print "\nNotebook server will be publicly available at: %s\n" % url
+
+ notebookapp.start()
+ return
+
+ # Drop the --flagfile flag so that notebook doesn't complain about an
+ # "unrecognized alias" when parsing sys.argv.
+ sys.argv = ([sys.argv[0]] +
+ [z for z in sys.argv[1:] if not z.startswith("--flagfile")])
+ from IPython.kernel.zmq.kernelapp import IPKernelApp # pylint: disable=g-import-not-at-top
+ kernelapp = IPKernelApp.instance()
+ kernelapp.initialize()
+
+ # Enable inline plotting. Equivalent to running "%matplotlib inline".
+ ipshell = kernelapp.shell
+ ipshell.enable_matplotlib("inline")
+
+ kernelapp.start()
+
+
+if __name__ == "__main__":
+ # When the user starts the main notebook process, we don't touch sys.argv.
+ # When the main process launches kernel subprocesses, it writes all flags
+ # to a tmpfile and sets --flagfile to that tmpfile, so for kernel
+ # subprocesses here we drop all flags *except* --flagfile, then call
+ # app.run(), and then (in main) restore all flags before starting the
+ # kernel app.
+ if IS_KERNEL:
+ # Drop everything except --flagfile.
+ sys.argv = ([sys.argv[0]] +
+ [x for x in sys.argv[1:] if x.startswith("--flagfile")])
+ app.run()
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
new file mode 100644
index 0000000000..7da9b41cf4
--- /dev/null
+++ b/tensorflow/python/client/session.py
@@ -0,0 +1,567 @@
+"""A client interface for TensorFlow."""
+
+import re
+import sys
+import threading
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow as tf_session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import logging
+
+
+class SessionInterface(object):
+ """Base class for implementations of TensorFlow client sessions."""
+
+ @property
+ def graph(self):
+ """The underlying TensorFlow graph, to be used in building Operations."""
+ raise NotImplementedError('graph')
+
+ @property
+ def sess_str(self):
+ """The TensorFlow process to which this session will connect."""
+ raise NotImplementedError('sess_str')
+
+ def run(self, fetches, feed_dict=None):
+ """Runs operations in the session. See `Session.run()` for details."""
+ raise NotImplementedError('Run')
+
+
+class BaseSession(SessionInterface):
+ """A class for interacting with a TensorFlow computation.
+
+ The BaseSession enables incremental graph building with inline
+ execution of Operations and evaluation of Tensors.
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Constructs a new TensorFlow session.
+
+ Args:
+ target: (Optional) The TensorFlow execution engine to connect to.
+ graph: (Optional) The graph to be used. If this argument is None,
+ the default graph will be used.
+ config: (Optional) ConfigProto proto used to configure the session.
+
+ Raises:
+ RuntimeError: If an error occurs while creating the TensorFlow
+ session.
+ """
+ if graph is None:
+ self._graph = ops.get_default_graph()
+ else:
+ self._graph = graph
+
+ self._opened = False
+ self._closed = False
+
+ self._current_version = 0
+ self._extend_lock = threading.Lock()
+ self._target = target
+
+ self._session = None
+
+ try:
+ opts = tf_session.TF_NewSessionOptions(target=target, config=config)
+ status = tf_session.TF_NewStatus()
+ self._session = tf_session.TF_NewSession(opts, status)
+ if tf_session.TF_GetCode(status) != 0:
+ message = tf_session.TF_Message(status)
+ raise RuntimeError(message)
+
+ finally:
+ tf_session.TF_DeleteSessionOptions(opts)
+ tf_session.TF_DeleteStatus(status)
+
+ def close(self):
+ """Closes this session.
+
+ Calling this method frees all resources associated with the session.
+
+ Raises:
+ RuntimeError: If an error occurs while closing the session.
+ """
+ with self._extend_lock:
+ if self._opened and not self._closed:
+ self._closed = True
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_CloseSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ def __del__(self):
+ self.close()
+ try:
+ status = tf_session.TF_NewStatus()
+ if self._session is not None:
+ tf_session.TF_DeleteSession(self._session, status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._session = None
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ @property
+ def graph(self):
+ """The graph that was launched in this session."""
+ return self._graph
+
+ @property
+ def graph_def(self):
+ """A serializable version of the underlying TensorFlow graph.
+
+ Returns:
+ A graph_pb2.GraphDef proto containing nodes for all of the Operations in
+ the underlying TensorFlow graph.
+ """
+ return self._graph.as_graph_def()
+
+ @property
+ def sess_str(self):
+ return self._target
+
+ def as_default(self):
+ """Returns a context manager that makes this object the default session.
+
+ Use with the `with` keyword to specify that calls to
+ [`Operation.run()`](framework.md#Operation.run) or
+ [`Tensor.run()`](framework.md#Tensor.run) should be executed in
+ this session.
+
+ ```python
+ c = tf.constant(..)
+ sess = tf.Session()
+
+ with sess.as_default():
+ assert tf.get_default_session() is sess
+ print c.eval()
+ ```
+
+ To get the current default session, use
+ [`tf.get_default_session()`](#get_default_session).
+
+
+ *N.B.* The `as_default` context manager *does not* close the
+ session when you exit the context, and you must close the session
+ explicitly.
+
+ ```python
+ c = tf.constant(...)
+ sess = tf.Session()
+ with sess.as_default():
+ print c.eval()
+ # ...
+ with sess.as_default():
+ print c.eval()
+
+ sess.close()
+ ```
+
+ Alternatively, you can use `with tf.Session():` to create a
+ session that is automatically closed on exiting the context,
+ including when an uncaught exception is raised.
+
+ *N.B.* The default graph is a property of the current thread. If you
+ create a new thread, and wish to use the default session in that
+ thread, you must explicitly add a `with sess.as_default():` in that
+ thread's function.
+
+ Returns:
+ A context manager using this session as the default session.
+
+ """
+ return ops.default_session(self)
+
+ # Eventually, this registration could be opened up to support custom
+ # Tensor expansions. Expects tuples of (Type, fetch_fn, feed_fn),
+ # where the signatures are:
+ # fetch_fn : Type -> (list of Tensors,
+ # lambda: list of fetched np.ndarray -> TypeVal)
+ # feed_fn : Type, TypeVal -> list of (Tensor, value)
+ # Conceptually, fetch_fn describes how to expand fetch into its
+ # component Tensors and how to contracting the fetched results back into
+ # a single return value. feed_fn describes how to unpack a single fed
+ # value and map it to feeds of a Tensor and its corresponding value.
+ # pylint: disable=g-long-lambda
+ _REGISTERED_EXPANSIONS = [
+ # SparseTensors are fetched as SparseTensorValues. They can be fed
+ # SparseTensorValues or normal tuples.
+ (ops.SparseTensor,
+ lambda fetch: (
+ [fetch.indices, fetch.values, fetch.shape],
+ lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
+ lambda feed, feed_val: list(zip(
+ [feed.indices, feed.values, feed.shape], feed_val))),
+ # The default catches all types and performs no expansions.
+ (object,
+ lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
+ lambda feed, feed_val: [(feed, feed_val)])]
+ # pylint: enable=g-long-lambda
+
+ def run(self, fetches, feed_dict=None):
+ """Runs the operations and evaluates the tensors in `fetches`.
+
+ This method runs one "step" of TensorFlow computation, by
+ running the necessary graph fragment to execute every `Operation`
+ and evaluate every `Tensor` in `fetches`, substituting the values in
+ `feed_dict` for the corresponding input values.
+
+ The `fetches` argument may be a list of graph elements or a single
+ graph element, and these determine the return value of this
+ method. A graph element can be one of the following types:
+
+ * If the *i*th element of `fetches` is an
+ [`Operation`](framework.md#Operation), the *i*th return value
+ will be `None`.
+ * If the *i*th element of `fetches` is a
+ [`Tensor`](framework.md#Tensor), the *i*th return value will
+ be a numpy ndarray containing the value of that tensor.
+ * If the *i*th element of `fetches` is a
+ [`SparseTensor`](sparse_ops.md#SparseTensor), the *i*th
+ return value will be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue)
+ containing the value of that sparse tensor.
+
+ The optional `feed_dict` argument allows the caller to override
+ the value of tensors in the graph. Each key in `feed_dict` can be
+ one of the following types:
+
+ * If the key is a [`Tensor`](framework.md#Tensor), the
+ value may be a Python scalar, string, list, or numpy ndarray
+ that can be converted to the same `dtype` as that
+ tensor. Additionally, if the key is a
+ [placeholder](io_ops.md#placeholder), the shape of the value
+ will be checked for compatibility with the placeholder.
+ * If the key is a [`SparseTensor`](sparse_ops.md#SparseTensor),
+ the value should be a
+ [`SparseTensorValue`](sparse_ops.md#SparseTensorValue).
+
+ Args:
+ fetches: A single graph element, or a list of graph elements
+ (described above).
+ feed_dict: A dictionary that maps graph elements to values
+ (described above).
+
+ Returns:
+ Either a single value if `fetches` is a single graph element, or
+ a list of values if `fetches` is a list (described above).
+
+ Raises:
+ RuntimeError: If this `Session` is in an invalid state (e.g. has been
+ closed).
+ TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
+ ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
+ `Tensor` that doesn't exist.
+
+ """
+ def _fetch_fn(fetch):
+ for tensor_type, fetch_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(fetch, tensor_type):
+ return fetch_fn(fetch)
+ raise TypeError('Fetch argument %r has invalid type %r'
+ % (fetch, type(fetch)))
+
+ def _feed_fn(feed, feed_val):
+ for tensor_type, _, feed_fn in BaseSession._REGISTERED_EXPANSIONS:
+ if isinstance(feed, tensor_type):
+ return feed_fn(feed, feed_val)
+ raise TypeError('Feed argument %r has invalid type %r'
+ % (feed, type(feed)))
+
+ # Check session.
+ if self._closed:
+ raise RuntimeError('Attempted to use a closed Session.')
+
+ # Validate and process fetches.
+ is_list_fetch = isinstance(fetches, (list, tuple))
+ if not is_list_fetch:
+ fetches = [fetches]
+
+ unique_fetch_targets = set()
+ target_list = []
+
+ fetch_info = []
+ for fetch in fetches:
+ subfetches, fetch_contraction_fn = _fetch_fn(fetch)
+ subfetch_names = []
+ for subfetch in subfetches:
+ try:
+ fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
+ allow_operation=True)
+ if isinstance(fetch_t, ops.Operation):
+ target_list.append(fetch_t.name)
+ else:
+ subfetch_names.append(fetch_t.name)
+ except TypeError as e:
+ raise TypeError('Fetch argument %r of %r has invalid type %r, '
+ 'must be a string or Tensor. (%s)'
+ % (subfetch, fetch, type(subfetch), e.message))
+ except ValueError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ except KeyError as e:
+ raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
+ 'Tensor. (%s)' % (subfetch, fetch, e.message))
+ unique_fetch_targets.update(subfetch_names)
+ fetch_info.append((subfetch_names, fetch_contraction_fn))
+
+ unique_fetch_targets = list(unique_fetch_targets)
+
+ # Create request.
+ feed_dict_string = {}
+
+ # Validate and process feed_dict.
+ if feed_dict:
+ for feed, feed_val in feed_dict.iteritems():
+ for subfeed, subfeed_val in _feed_fn(feed, feed_val):
+ try:
+ subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
+ allow_operation=False)
+ except Exception as e:
+ e.message = ('Cannot interpret feed_dict key as Tensor: '
+ + e.message)
+ e.args = (e.message,)
+ raise e
+ np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
+ if subfeed_t.op.type == 'Placeholder':
+ if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
+ raise ValueError(
+ 'Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r'
+ % (np_val.shape, subfeed_t.name,
+ tuple(subfeed_t.get_shape().dims)))
+ feed_dict_string[str(subfeed_t.name)] = np_val
+
+ # Run request and get response.
+ results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
+
+ # User may have fetched the same tensor multiple times, but we
+ # only fetch them from the runtime once. Furthermore, they may
+ # be wrapped as a tuple of tensors. Here we map the results back
+ # to what the client asked for.
+ fetched_results = dict(zip(unique_fetch_targets, results))
+ ret = []
+ for fetch_names, fetch_contraction_fn in fetch_info:
+ if fetch_names:
+ fetched_vals = [fetched_results[name] for name in fetch_names]
+ ret.append(fetch_contraction_fn(fetched_vals))
+ else:
+ ret.append(None)
+
+ if is_list_fetch:
+ return ret
+ else:
+ return ret[0]
+
+ # Captures the name of a node in an error status.
+ _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
+
+ def _do_run(self, target_list, fetch_list, feed_dict):
+ """Runs a step based on the given fetches and feeds.
+
+ Args:
+ target_list: A list of strings corresponding to names of tensors
+ or operations to be run to, but not fetched.
+ fetch_list: A list of strings corresponding to names of tensors to be
+ fetched and operations to be run.
+ feed_dict: A dictionary that maps tensor names to numpy ndarrays.
+
+ Returns:
+ A list of numpy ndarrays, corresponding to the elements of
+ `fetch_list`. If the ith element of `fetch_list` contains the
+ name of an operation, the first Tensor output of that operation
+ will be returned for that element.
+ """
+ try:
+ # Ensure any changes to the graph are reflected in the runtime.
+ with self._extend_lock:
+ if self._graph.version > self._current_version:
+ graph_def = self._graph.as_graph_def(
+ from_version=self._current_version)
+
+ try:
+ status = tf_session.TF_NewStatus()
+ tf_session.TF_ExtendGraph(
+ self._session, graph_def.SerializeToString(), status)
+ if tf_session.TF_GetCode(status) != 0:
+ raise RuntimeError(tf_session.TF_Message(status))
+ self._opened = True
+ finally:
+ tf_session.TF_DeleteStatus(status)
+
+ self._current_version = self._graph.version
+
+ return tf_session.TF_Run(self._session, feed_dict, fetch_list,
+ target_list)
+
+ except tf_session.StatusNotOK as e:
+ e_type, e_value, e_traceback = sys.exc_info()
+ m = BaseSession._NODEDEF_NAME_RE.search(e.error_message)
+ if m is not None:
+ node_name = m.group(1)
+ node_def = None
+ try:
+ op = self._graph.get_operation_by_name(node_name)
+ node_def = op.node_def
+ except KeyError:
+ op = None
+ # pylint: disable=protected-access
+ raise errors._make_specific_exception(node_def, op, e.error_message,
+ e.code)
+ # pylint: enable=protected-access
+ raise e_type, e_value, e_traceback
+
+
+class Session(BaseSession):
+ """A class for running TensorFlow operations.
+
+ A `Session` object encapsulates the environment in which `Operation`
+ objects are executed, and `Tensor` objects are evaluated. For
+ example:
+
+ ```python
+ # Build a graph.
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+
+ # Launch the graph in a session.
+ sess = tf.Session()
+
+ # Evaluate the tensor `c`.
+ print sess.run(c)
+ ```
+
+ A session may own resources, such as
+ [variables](state_ops.md#Variable), [queues](io_ops.md#QueueBase),
+ and [readers](io_ops.md#ReaderBase). It is important to release
+ these resources when they are no longer required. To do this, either
+ invoke the [`close()`](#Session.close) method on the session, or use
+ the session as a context manager. The following two examples are
+ equivalent:
+
+ ```python
+ # Using the `close()` method.
+ sess = tf.Session()
+ sess.run(...)
+ sess.close()
+
+ # Using the context manager.
+ with tf.Session() as sess:
+ sess.run(...)
+ ```
+
+ The [`ConfigProto`]
+ (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer exposes various configuration options for a
+ session. For example, to create a session that uses soft constraints
+ for device placement, and log the resulting placement decisions,
+ create a session as follows:
+
+ ```python
+ # Launch the graph in a session that allows soft device placement and
+ # logs the placement decisions.
+ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
+ log_device_placement=True))
+ ```
+
+ @@__init__
+ @@run
+ @@close
+
+ @@graph
+
+ @@as_default
+
+ """
+
+ def __init__(self, target='', graph=None, config=None):
+ """Creates a new TensorFlow session.
+
+ If no `graph` argument is specified when constructing the session,
+ the default graph will be launched in the session. If you are
+ using more than one graph (created with `tf.Graph()` in the same
+ process, you will have to use different sessions for each graph,
+ but each graph can be used in multiple sessions. In this case, it
+ is often clearer to pass the graph to be launched explicitly to
+ the session constructor.
+
+ Args:
+ target: (Optional.) The execution engine to connect to.
+ Defaults to using an in-process engine. At present, no value
+ other than the empty string is supported.
+ graph: (Optional.) The `Graph` to be launched (described above).
+ config: (Optional.) A [`ConfigProto`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/config.proto)
+ protocol buffer with configuration options for the session.
+
+ """
+ super(Session, self).__init__(target, graph, config=config)
+ self._context_managers = [self.graph.as_default(), self.as_default()]
+
+ def __enter__(self):
+ for context_manager in self._context_managers:
+ context_manager.__enter__()
+ return self
+
+ def __exit__(self, exec_type, exec_value, exec_tb):
+ if exec_type is errors.OpError:
+ logging.error('Session closing due to OpError: %s', (exec_value,))
+
+ for context_manager in reversed(self._context_managers):
+ context_manager.__exit__(exec_type, exec_value, exec_tb)
+
+ self.close()
+
+
+class InteractiveSession(BaseSession):
+ """A TensorFlow `Session` for use in interactive contexts, such as a shell.
+
+ In some cases, such as interactive shells and IPython notebooks, it is
+ useful to be able to define a `Session` without using a with block: this
+ style enables statements to be executed immediately, rather than at the
+ termination of the block. In that case, it must be closed using
+ `Session.close()`. For example:
+
+ ```python
+ sess = InteractiveSession()
+ a = tf.constant(5.0)
+ b = tf.constant(6.0)
+ c = a * b
+ print c.run()
+ sess.close()
+ ```
+
+ @@__init__
+ @@close
+ """
+
+ def __init__(self, target='', graph=None):
+ """Initializes an `InteractiveSession` object similar to `Session`.
+
+ Args:
+ target: Optional. The TensorFlow execution engine to connect to.
+ graph: Optional. The `Graph` object to be used. If this argument is None,
+ the default graph will be used.
+ """
+ super(InteractiveSession, self).__init__(target, graph)
+ self._default_session = self.as_default()
+ self._default_session.__enter__()
+ self._explicit_graph = graph
+ if self._explicit_graph is not None:
+ self._default_graph = graph.as_default()
+ self._default_graph.__enter__()
+
+ def close(self):
+ """Closes an `InteractiveSession`."""
+ super(InteractiveSession, self).close()
+ if self._explicit_graph is not None:
+ self._default_graph.__exit__(None, None, None)
+ self._default_session.__exit__(None, None, None)
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
new file mode 100644
index 0000000000..4492840dcf
--- /dev/null
+++ b/tensorflow/python/client/session_test.py
@@ -0,0 +1,555 @@
+"""Tests for tensorflow.python.client.session.Session."""
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.core.framework import config_pb2
+from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registration for op used in the tests.
+ops.RegisterShape('ConstructionFails')(None)
+
+
+class SessionTest(test_util.TensorFlowTestCase):
+
+ def testUseExistingGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session(graph=g):
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testUseDefaultGraph(self):
+ with ops.Graph().as_default(), ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session():
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testCreate(self):
+ with session.Session():
+ inp = constant_op.constant(10.0, name='W1')
+ copy = array_ops.identity(inp)
+ # Test with feed.
+ # TODO(mrry): Investigate why order='F' didn't work.
+ arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
+ copy_val = copy.eval({'W1:0': arr})
+ self.assertAllEqual(arr, copy_val)
+ # Test without feed.
+ copy_val = copy.eval()
+ self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val)
+
+ def testManyCPUs(self):
+ # TODO(keveman): Implement ListDevices and test for the number of
+ # devices returned by ListDevices.
+ with session.Session(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2})):
+ inp = constant_op.constant(10.0, name='W1')
+ self.assertAllEqual(inp.eval(), 10.0)
+
+ def testErrorsReported(self):
+ with session.Session() as s:
+ constant_op.constant(10.0, name='W1')
+ with self.assertRaises(ValueError):
+ s.run('foo:0')
+
+ def testErrorPayload(self):
+ with session.Session():
+ a = array_ops.placeholder(types.float32)
+ with self.assertRaisesOpError(lambda e: e.op == a.op):
+ a.eval()
+
+ def testOpConstructionErrorPayload(self):
+ with session.Session():
+ failing_op = ops.get_default_graph().create_op(
+ 'ConstructionFails', [], [], name='f')
+
+ def exc_predicate(e):
+ return (e.op == failing_op
+ and e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ with self.assertRaisesOpError(exc_predicate):
+ failing_op.run()
+
+ def testErrorBasedOn(self):
+ with session.Session() as sess:
+ a = constant_op.constant(0.0, shape=[2, 3])
+ # NOTE(mrry): The original_op is nonsense, but used here to test that the
+ # errors are reported correctly.
+ # pylint: disable=protected-access
+ with sess.graph._original_op(a.op):
+ b = array_ops.identity(a, name='id')
+ with sess.graph._original_op(b.op):
+ c = array_ops.placeholder(types.float32)
+ # pylint: enable=protected-access
+
+ def exc_predicate(e):
+ return (e.op == c.op
+ and e.op._original_op == b.op
+ and e.op._original_op._original_op == a.op)
+ with self.assertRaisesOpError(exc_predicate):
+ c.eval()
+
+ def testFetchTensorObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ results_with_list = s.run([c])
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
+ results_with_single = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
+ results_with_get = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
+ a_val, b_val = s.run([a, b]) # Test multiple fetches.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+ self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
+
+ def testFetchScalar(self):
+ with session.Session() as s:
+ for scalar in np.int32, np.int64, np.float32, np.float64:
+ x = scalar(7)
+ y = scalar(8)
+ tf_x = constant_op.constant(x, shape=[])
+ tf_y = constant_op.constant(y)
+ tf_xy = math_ops.add(tf_x, tf_y)
+ # Single fetch
+ xy = s.run(tf_xy)
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+ # List fetch
+ xy, = s.run([tf_xy])
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+
+ def testFetchOperationObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ v = variables.Variable(a, name='testFetchOperationObject_v')
+ s.run(v.initializer)
+ v_val = s.run(v)
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ def testFetchSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ constant_op.constant(indices),
+ constant_op.constant(values),
+ constant_op.constant(shape))
+ # Single fetch, use as tuple
+ sp_out = s.run(sp)
+ indices_out, values_out, shape_out = sp_out
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Single fetch, use as SparseTensorValue
+ sp_out = s.run(sp)
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+ # Tuple fetch, use as tuple
+ indices_out, values_out, shape_out = s.run(sp)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as tuple
+ (indices_out, values_out, shape_out), = s.run([sp])
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as SparseTensorValue
+ sp_out, = s.run([sp])
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+
+ def testFeedSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),)
+ sp_indices = array_ops.identity(sp.indices)
+ sp_values = array_ops.identity(sp.values)
+ sp_shape = array_ops.identity(sp.shape)
+ sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape)
+ # Feed with tuple
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape],
+ {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue, fetch SparseTensorValue
+ sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(sp2_out.indices, indices)
+ self.assertAllEqual(sp2_out.values, values)
+ self.assertAllEqual(sp2_out.shape, shape)
+
+ def testExtendWithStatelessOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ c_val = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ # Extend will happen here.
+ e_val = s.run(e)
+ self.assertAllEqual([[24.0]], e_val)
+
+ def testExtendWithStatefulOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ # Extend will happen here.
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+
+ def testExtendWithGroupBy(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ p = variables.Variable(a, name='testExtendWithGroupBy_p')
+ a_val = a.eval() # Force an Extend after this op.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+
+ b = constant_op.constant(2.0, shape=[1, 2])
+ q = variables.Variable(b, name='testExtendWithGroupBy_q')
+ # Extend will happen here.
+ init = control_flow_ops.group(p.initializer, q.initializer)
+ s.run(init)
+ p_val, q_val = s.run([p, q])
+
+ self.assertAllEqual([[1.0, 1.0]], p_val)
+ self.assertAllEqual([[2.0, 2.0]], q_val)
+
+ def testTensorGetMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ c_val = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+
+ fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
+ self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
+
+ def testOperationRunMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 2], name='b')
+ v = variables.Variable(a, a.dtype)
+ assign_a_to_v = state_ops.assign(v, a)
+
+ assign_a_to_v.eval()
+
+ v_val = v.eval()
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ assign_b_to_v = state_ops.assign(v, b)
+
+ assign_b_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[2.0, 2.0]], v_val)
+
+ assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
+ v_val = v.eval()
+ self.assertAllEqual([[3.0, 3.0]], v_val)
+
+ def testDefaultGraph(self):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ self.assertEqual(ops.get_default_graph(), a.graph)
+ self.assertEqual(ops.get_default_graph(), b.graph)
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testDefaultGraph_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='var_%d' % i)
+
+ # Block here until all threads have constructed their graph.
+ constructed_event.set()
+ continue_event.wait()
+
+ assign_c_to_v = state_ops.assign(v, c)
+ v.initializer.run()
+ assign_c_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def testDefaultGraphWithThreads(self):
+ # Fork ten threads that use their thread-local default graph.
+ threads = []
+ constructed_events = [threading.Event() for _ in range(10)]
+ continue_event = threading.Event()
+ for i, constructed_event in enumerate(constructed_events):
+ t = self.checkedThread(target=self._testDefaultGraphInThread,
+ args=(constructed_event, continue_event, i))
+ threads.append(t)
+ for t in threads:
+ t.start()
+ for constructed_event in constructed_events:
+ constructed_event.wait()
+ continue_event.set()
+ for t in threads:
+ t.join()
+
+ def testParallelRun(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ ev = threading.Event()
+
+ def run_step():
+ ev.wait()
+ val = c.eval(session=sess)
+ self.assertEqual(val, 5.0)
+ threads = [self.checkedThread(target=run_step) for _ in range(100)]
+ for t in threads:
+ t.start()
+ ev.set()
+ for t in threads:
+ t.join()
+
+ def testRunFeedDict(self):
+ with session.Session() as s:
+ x = array_ops.zeros([2])
+
+ y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x: [1, 1]})
+ assert (y == 2 * np.ones(2)).all()
+
+ def testGraphDef(self):
+ with session.Session() as sess:
+ self.assertProtoEquals('', sess.graph_def)
+ c = constant_op.constant(5.0, name='c')
+ self.assertEquals(len(sess.graph_def.node), 1)
+ d = constant_op.constant(6.0, name='d')
+ self.assertEquals(len(sess.graph_def.node), 2)
+ self.assertAllEqual(c.eval(), 5.0)
+ self.assertAllEqual(d.eval(), 6.0)
+ e = constant_op.constant(7.0, name='e')
+ self.assertEquals(len(sess.graph_def.node), 3)
+ self.assertAllEqual(e.eval(), 7.0)
+
+ def testUseAfterClose(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
+ sess.run(c)
+
+ def testUseAfterCloseConcurrent(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+
+ def update_thread():
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError,
+ lambda e: 'Attempted to use a closed Session.' in str(e)):
+ while True:
+ sess.run(c)
+ t = threading.Thread(target=update_thread)
+ t.start()
+ time.sleep(0.1)
+ sess.close()
+ t.join()
+
+ def testNotEntered(self):
+ # pylint: disable=protected-access
+ self.assertEqual(ops._default_session_stack.get_default(), None)
+ # pylint: enable=protected-access
+ with ops.device('/cpu:0'):
+ sess = session.Session()
+ c_1 = constant_op.constant(5.0)
+ with sess.graph.as_default():
+ c_2 = constant_op.constant(5.0)
+ self.assertEqual(c_1.graph, c_2.graph)
+ self.assertEqual(sess.run(c_2), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: 'No default session is registered.' in str(e)):
+ c_2.eval()
+
+ def testInteractive(self):
+ with ops.device('/cpu:0'):
+ sess = session.InteractiveSession()
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ self.assertAllEqual([[24.0]], e.eval())
+ sess.close()
+
+ def testSharedGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ self.assertAllEqual(sess1.run(c), sess2.run(c))
+
+ def testDuplicatedInputs(self):
+ with session.Session() as sess:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 3])
+ a_val, b_val, a2_val = sess.run([a, b, a])
+ self.assertAllEqual(a_val, [[1.0, 1.0]])
+ self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
+ self.assertAllEqual(a2_val, [[1.0, 1.0]])
+
+ def testFeedAndFetch(self):
+ with session.Session():
+ for dtype in [types.float32,
+ types.float64,
+ types.int32,
+ types.uint8,
+ types.int16,
+ types.int8,
+ types.int64,
+ types.bool,
+ types.complex64]:
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ np_dtype = dtype.as_numpy_dtype
+
+ feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
+ out_t = array_ops.identity(feed_t)
+
+ np_array = np.random.randint(-10, 10, shape)
+
+ if dtype == types.bool:
+ np_array = np_array > 0
+ elif dtype == types.complex64:
+ np_array = np.sqrt(np_array.astype(np_dtype))
+ else:
+ np_array = np_array.astype(np_dtype)
+
+ self.assertAllEqual(np_array,
+ out_t.eval(feed_dict={feed_t: np_array}))
+
+ def testStringFetch(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape) if size > 0 else []
+ c = constant_op.constant(c_list)
+ self.assertAllEqual(c.eval(), c_list)
+
+ def testStringFeed(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape)
+ feed_t = array_ops.placeholder(dtype=types.string, shape=shape)
+ c = array_ops.identity(feed_t)
+ self.assertAllEqual(c.eval(feed_dict={feed_t: c_list}), c_list)
+
+ def testStringFeedWithNullCharacters(self):
+ with session.Session():
+ c_list = ['\n\x01\x00', '\n\x00\x01']
+ feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
+ c = array_ops.identity(feed_t)
+ out = c.eval(feed_dict={feed_t: c_list})
+ self.assertEqual(c_list[0], out[0])
+ self.assertEqual(c_list[1], out[1])
+
+ def testInvalidTargetFails(self):
+ with self.assertRaises(RuntimeError):
+ session.Session("INVALID_TARGET")
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/client/tensorflow_server.i b/tensorflow/python/client/tensorflow_server.i
new file mode 100644
index 0000000000..65b3826961
--- /dev/null
+++ b/tensorflow/python/client/tensorflow_server.i
@@ -0,0 +1,16 @@
+%include "tensorflow/python/platform/base.i"
+%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
+
+%{
+#include "tensorflow/core/public/tensorflow_server.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::LaunchTensorFlow;
+
+%include "tensorflow/core/public/tensorflow_server.h"
+
+%unignoreall
+
diff --git a/tensorflow/python/client/test_construction_fails_op.cc b/tensorflow/python/client/test_construction_fails_op.cc
new file mode 100644
index 0000000000..47b2b5b49c
--- /dev/null
+++ b/tensorflow/python/client/test_construction_fails_op.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ConstructionFails");
+
+class ConstructionFailsOp : public OpKernel {
+ public:
+ explicit ConstructionFailsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES(ctx, false,
+ errors::InvalidArgument("Failure during construction."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("ConstructionFails").Device(DEVICE_CPU),
+ ConstructionFailsOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
new file mode 100644
index 0000000000..30e80f779f
--- /dev/null
+++ b/tensorflow/python/client/tf_session.i
@@ -0,0 +1,235 @@
+%include "tensorflow/python/platform/base.i"
+
+%{
+
+#include "numpy/arrayobject.h"
+
+#include "tensorflow/python/client/tf_session_helper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+%}
+
+// Implements the StatusNotOK exception.
+%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
+
+// Required to use PyArray_* functions.
+%include "tensorflow/python/platform/numpy.i"
+%init %{
+import_array();
+%}
+
+// Release the Python GIL for the duration of most methods.
+%exception {
+ Py_BEGIN_ALLOW_THREADS;
+ $action
+ Py_END_ALLOW_THREADS;
+}
+
+// Proto input arguments to C API functions are passed as a (const
+// void*, size_t) pair. In Python, typemap these to a single string
+// argument.
+%typemap(in) (const void* proto, size_t proto_len) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+ $1 = static_cast<void*>(c_string);
+ $2 = static_cast<size_t>(py_size);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
+////////////////////////////////////////////////////////////////////////////////
+
+// The wrapper takes a vector of pairs of feed names and feed
+// values. In Python this is represented as dictionary mapping strings
+// to numpy arrays.
+%typemap(in) const tensorflow::FeedVector& inputs (
+ tensorflow::FeedVector temp,
+ tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr)),
+ tensorflow::Safe_PyObjectPtr temp_array_list(tensorflow::make_safe(nullptr))) {
+ if (!PyDict_Check($input)) {
+ SWIG_fail;
+ }
+
+ temp_string_list = tensorflow::make_safe(PyList_New(0));
+ if (!temp_string_list) {
+ SWIG_fail;
+ }
+ temp_array_list = tensorflow::make_safe(PyList_New(0));
+ if (!temp_array_list) {
+ SWIG_fail;
+ }
+
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next($input, &pos, &key, &value)) {
+ const char* key_string = PyString_AsString(key);
+ if (!key_string) {
+ SWIG_fail;
+ }
+
+ // The ndarray must be stored as contiguous bytes in C (row-major) order.
+ PyObject* array_object = PyArray_FromAny(
+ value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr);
+ if (!array_object) {
+ SWIG_fail;
+ }
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_object);
+
+ // Keep a reference to the key and the array, in case the incoming dict is
+ // modified, and/or to avoid leaking references on failure.
+ if (PyList_Append(temp_string_list.get(), key) == -1) {
+ SWIG_fail;
+ }
+ if (PyList_Append(temp_array_list.get(), array_object) == -1) {
+ SWIG_fail;
+ }
+
+ temp.push_back(std::make_pair(key_string, array));
+ }
+ $1 = &temp;
+}
+
+// The wrapper also takes a list of fetch and target names. In Python this is
+// represented as a list of strings.
+%typemap(in) const tensorflow::NameVector& (
+ tensorflow::NameVector temp,
+ tensorflow::Safe_PyObjectPtr temp_string_list(tensorflow::make_safe(nullptr))) {
+ if (!PyList_Check($input)) {
+ SWIG_fail;
+ }
+
+ Py_ssize_t len = PyList_Size($input);
+
+ temp_string_list = tensorflow::make_safe(PyList_New(len));
+ if (!temp_string_list) {
+ SWIG_fail;
+ }
+
+ for (Py_ssize_t i = 0; i < len; ++i) {
+ PyObject* elem = PyList_GetItem($input, i);
+ if (!elem) {
+ SWIG_fail;
+ }
+
+ // Keep a reference to the string in case the incoming list is modified.
+ PyList_SET_ITEM(temp_string_list.get(), i, elem);
+ Py_INCREF(elem);
+
+ const char* fetch_name = PyString_AsString(elem);
+ if (!fetch_name) {
+ PyErr_SetString(PyExc_TypeError,
+ "a fetch or target name was not a string");
+ SWIG_fail;
+ }
+
+ // TODO(mrry): Avoid copying the fetch name in, if this impacts performance.
+ temp.push_back(fetch_name);
+ }
+ $1 = &temp;
+}
+
+
+// The wrapper has two outputs: a tensorflow::Status, and a vector of
+// PyObjects containing the fetch results (iff the status is OK). Since
+// the interpretation of the vector depends on the status, we define
+// them as two consecutive out arguments, so that they can be accessed
+// together in a typemap.
+
+// Define temporaries for the argout outputs.
+%typemap(in, numinputs=0) tensorflow::Status* out_status (
+ tensorflow::Status temp) {
+ $1 = &temp;
+}
+%typemap(in, numinputs=0) tensorflow::PyObjectVector* out_values (
+ tensorflow::PyObjectVector temp) {
+ $1 = &temp;
+}
+
+// Raise a StatusNotOK exception if the out_status is not OK;
+// otherwise build a Python list of outputs and return it.
+%typemap(argout, fragment="StatusNotOK") (
+ tensorflow::Status* out_status, tensorflow::PyObjectVector* out_values) {
+ if (!$1->ok()) {
+ RaiseStatusNotOK(*$1, $descriptor(tensorflow::Status*));
+ SWIG_fail;
+ } else {
+ tensorflow::Safe_PyObjectVector out_values_safe;
+ for (int i = 0; i < $2->size(); ++i) {
+ out_values_safe.emplace_back(tensorflow::make_safe($2->at(i)));
+ }
+
+ $result = PyList_New($2->size());
+ if (!$result) {
+ SWIG_fail;
+ }
+
+ for (int i = 0; i < $2->size(); ++i) {
+ PyList_SET_ITEM($result, i, $2->at(i));
+ out_values_safe[i].release();
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
+////////////////////////////////////////////////////////////////////////////////
+
+
+
+// Include the functions from tensor_c_api.h, except TF_Run.
+%ignoreall
+%unignore TF_Code;
+%unignore TF_Status;
+%unignore TF_NewStatus;
+%unignore TF_DeleteStatus;
+%unignore TF_GetCode;
+%unignore TF_Message;
+%unignore TF_SessionOptions;
+%rename("_TF_SetTarget") TF_SetTarget;
+%rename("_TF_SetConfig") TF_SetConfig;
+%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
+%unignore TF_DeleteSessionOptions;
+%unignore TF_NewSession;
+%unignore TF_CloseSession;
+%unignore TF_DeleteSession;
+%unignore TF_ExtendGraph;
+%include "tensorflow/core/public/tensor_c_api.h"
+%ignoreall
+
+%insert("python") %{
+ def TF_NewSessionOptions(target=None, config=None):
+ opts = _TF_NewSessionOptions()
+ if target is not None:
+ _TF_SetTarget(opts, target)
+ if config is not None:
+ from tensorflow.core.framework import config_pb2
+ if not isinstance(config, config_pb2.ConfigProto):
+ raise TypeError("Expected config_pb2.ConfigProto, "
+ "but got %s" % type(config))
+ status = TF_NewStatus()
+ config_str = config.SerializeToString()
+ _TF_SetConfig(opts, config_str, len(config_str), status)
+ if TF_GetCode(status) != 0:
+ raise ValueError(TF_Message(status))
+ return opts
+%}
+
+// Include the wrapper for TF_Run from tf_session_helper.h.
+
+// The %exception block above releases the Python GIL for the length
+// of each wrapped method. We disable this behavior for TF_Run
+// because it uses the Python allocator.
+%noexception tensorflow::TF_Run_wrapper;
+%rename(TF_Run) tensorflow::TF_Run_wrapper;
+%unignore tensorflow;
+%unignore TF_Run;
+
+%include "tensorflow/python/client/tf_session_helper.h"
+
+%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
new file mode 100644
index 0000000000..06483da87b
--- /dev/null
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -0,0 +1,518 @@
+#include "tensorflow/python/client/tf_session_helper.h"
+
+#include <cstring>
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Container types for the various temporary values used internally in
+// the wrapper.
+
+// A TF_TensorVector is a vector of borrowed pointers to TF_Tensors.
+typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector;
+
+// Safe containers for (an) owned TF_Tensor(s). On destruction, the
+// tensor will be deleted by TF_DeleteTensor.
+typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
+ Safe_TF_TensorPtr;
+typedef std::vector<Safe_TF_TensorPtr> Safe_TF_TensorVector;
+Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
+ return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
+}
+
+// Safe container for an owned TF_Status. On destruction, the status
+// will be deleted by TF_DeleteStatus.
+typedef std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>
+ Safe_TF_StatusPtr;
+Safe_TF_StatusPtr make_safe(TF_Status* status) {
+ return Safe_TF_StatusPtr(status, TF_DeleteStatus);
+}
+
+Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
+ TF_DataType* out_tf_datatype) {
+ PyObject* key;
+ PyObject* value;
+ Py_ssize_t pos = 0;
+ if (PyDict_Next(descr->fields, &pos, &key, &value)) {
+ const char* key_string = PyString_AsString(key);
+ if (!key_string) {
+ return errors::Internal("Corrupt numpy type descriptor");
+ }
+ tensorflow::string key = key_string;
+ // The typenames here should match the field names in the custom struct
+ // types constructed in test_util.py.
+ // TODO(mrry,keveman): Investigate Numpy type registration to replace this
+ // hard-coding of names.
+ if (key == "quint8") {
+ *out_tf_datatype = TF_QUINT8;
+ } else if (key == "qint8") {
+ *out_tf_datatype = TF_QINT8;
+ } else if (key == "qint32") {
+ *out_tf_datatype = TF_QINT32;
+ } else {
+ return errors::Internal("Unsupported numpy data type");
+ }
+ return Status::OK();
+ }
+ return errors::Internal("Unsupported numpy data type");
+}
+
+Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
+ TF_DataType* out_tf_datatype) {
+ int pyarray_type = PyArray_TYPE(array);
+ PyArray_Descr* descr = array->descr;
+ switch (pyarray_type) {
+ case NPY_FLOAT32:
+ *out_tf_datatype = TF_FLOAT;
+ break;
+ case NPY_FLOAT64:
+ *out_tf_datatype = TF_DOUBLE;
+ break;
+ case NPY_INT32:
+ *out_tf_datatype = TF_INT32;
+ break;
+ case NPY_UINT8:
+ *out_tf_datatype = TF_UINT8;
+ break;
+ case NPY_INT16:
+ *out_tf_datatype = TF_INT16;
+ break;
+ case NPY_INT8:
+ *out_tf_datatype = TF_INT8;
+ break;
+ case NPY_INT64:
+ *out_tf_datatype = TF_INT64;
+ break;
+ case NPY_BOOL:
+ *out_tf_datatype = TF_BOOL;
+ break;
+ case NPY_COMPLEX64:
+ *out_tf_datatype = TF_COMPLEX;
+ break;
+ case NPY_OBJECT:
+ *out_tf_datatype = TF_STRING;
+ break;
+ case NPY_VOID:
+ // Quantized types are currently represented as custom struct types.
+ // PyArray_TYPE returns NPY_VOID for structs, and we should look into
+ // descr to derive the actual type.
+ return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
+ default:
+ // TODO(mrry): Support these.
+ return errors::Internal("Unsupported feed type");
+ }
+ return Status::OK();
+}
+
+Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
+ int* out_pyarray_type) {
+ switch (tf_datatype) {
+ case TF_FLOAT:
+ *out_pyarray_type = NPY_FLOAT32;
+ break;
+ case TF_DOUBLE:
+ *out_pyarray_type = NPY_FLOAT64;
+ break;
+ case TF_INT32:
+ *out_pyarray_type = NPY_INT32;
+ break;
+ case TF_UINT8:
+ *out_pyarray_type = NPY_UINT8;
+ break;
+ case TF_INT16:
+ *out_pyarray_type = NPY_INT16;
+ break;
+ case TF_INT8:
+ *out_pyarray_type = NPY_INT8;
+ break;
+ case TF_INT64:
+ *out_pyarray_type = NPY_INT64;
+ break;
+ case TF_BOOL:
+ *out_pyarray_type = NPY_BOOL;
+ break;
+ case TF_COMPLEX:
+ *out_pyarray_type = NPY_COMPLEX64;
+ break;
+ case TF_STRING:
+ *out_pyarray_type = NPY_OBJECT;
+ break;
+ // TODO(keveman): These should be changed to NPY_VOID, and the type used for
+ // the resulting numpy array should be the custom struct types that we
+ // expect for quantized types.
+ case TF_QINT8:
+ *out_pyarray_type = NPY_INT8;
+ break;
+ case TF_QUINT8:
+ *out_pyarray_type = NPY_UINT8;
+ break;
+ case TF_QINT32:
+ *out_pyarray_type = NPY_INT32;
+ break;
+ case TF_BFLOAT16:
+ *out_pyarray_type = NPY_UINT16;
+ break;
+ default:
+ return errors::Internal("Unsupported fetch type");
+ }
+ return Status::OK();
+}
+
+// Iterate over the string array 'array', extract the ptr and len of each string
+// element and call f(ptr, len).
+template <typename F>
+Status PyStringArrayMap(PyArrayObject* array, F f) {
+ Safe_PyObjectPtr iter = tensorflow::make_safe(
+ PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
+ while (PyArray_ITER_NOTDONE(iter.get())) {
+ auto item = tensorflow::make_safe(
+ PyArray_GETITEM(array, PyArray_ITER_DATA(iter.get())));
+ if (!item.get()) {
+ return errors::Internal("Unable to get element from the feed.");
+ }
+ char* ptr;
+ Py_ssize_t len;
+ int success = PyString_AsStringAndSize(item.get(), &ptr, &len);
+ if (success != 0) {
+ return errors::Internal("Unable to get element from the feed.");
+ }
+ f(ptr, len);
+ PyArray_ITER_NEXT(iter.get());
+ }
+ return Status::OK();
+}
+
+// Encode the strings in 'array' into a contiguous buffer and return the base of
+// the buffer. The caller takes ownership of the buffer.
+Status EncodePyStringArray(PyArrayObject* array, tensorflow::int64 nelems,
+ size_t* size, void** buffer) {
+ // Compute bytes needed for encoding.
+ *size = 0;
+ TF_RETURN_IF_ERROR(
+ PyStringArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
+ *size += sizeof(tensorflow::uint64) +
+ tensorflow::core::VarintLength(len) + len;
+ }));
+ // Encode all strings.
+ std::unique_ptr<char[]> base_ptr(new char[*size]);
+ char* base = base_ptr.get();
+ char* data_start = base + sizeof(tensorflow::uint64) * nelems;
+ char* dst = data_start; // Where next string is encoded.
+ tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
+
+ TF_RETURN_IF_ERROR(PyStringArrayMap(
+ array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
+ *offsets = (dst - data_start);
+ offsets++;
+ dst = tensorflow::core::EncodeVarint64(dst, len);
+ memcpy(dst, ptr, len);
+ dst += len;
+ }));
+ CHECK_EQ(dst, base + *size);
+ *buffer = base_ptr.release();
+ return Status::OK();
+}
+
+// Determine the pointer and offset of the string at offset 'i' in the string
+// tensor 'src', whose total length is 'num_elements'.
+static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src,
+ tensorflow::int64 num_elements,
+ tensorflow::int64 i,
+ const char** ptr,
+ tensorflow::uint64* len) {
+ const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
+ const size_t src_size = TF_TensorByteSize(src);
+ const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
+ const char* limit = input + src_size;
+ tensorflow::uint64 offset =
+ reinterpret_cast<const tensorflow::uint64*>(input)[i];
+ const char* p =
+ tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len);
+ if (offset >= (limit - data_start) || !p || (*len > (limit - p))) {
+ return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i,
+ " out of range");
+ }
+ *ptr = p;
+ return Status::OK();
+}
+
+// Copy the string at offset 'i' in the (linearized) string tensor 'tensor' into
+// 'pyarray' at offset pointed by the 'i_ptr' iterator.
+static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
+ TF_Tensor* tensor,
+ tensorflow::int64 num_elements,
+ tensorflow::int64 i) {
+ const char* ptr;
+ tensorflow::uint64 len;
+ TF_RETURN_IF_ERROR(
+ TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
+ auto py_string = tensorflow::make_safe(PyString_FromStringAndSize(ptr, len));
+ int success =
+ PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get());
+ if (success != 0) {
+ return errors::Internal("Error setting element ", i);
+ }
+ return Status::OK();
+}
+
+// Converts the given TF_Tensor to a Numpy array.
+// If the returned status is OK, the caller becomes the owner of *out_array.
+Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) {
+ // A fetched operation will correspond to a null tensor, and a None
+ // in Python.
+ if (tensor == nullptr) {
+ Py_INCREF(Py_None);
+ *out_array = Py_None;
+ return Status::OK();
+ }
+
+ const int ndims = TF_NumDims(tensor);
+ gtl::InlinedVector<npy_intp, 4> dims(ndims);
+ tensorflow::int64 nelems = 1;
+ for (int i = 0; i < ndims; ++i) {
+ dims[i] = TF_Dim(tensor, i);
+ nelems *= dims[i];
+ }
+
+ // Convert TensorFlow dtype to numpy type descriptor.
+ int type_num;
+ TF_RETURN_IF_ERROR(
+ TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
+ PyArray_Descr* descr = PyArray_DescrFromType(type_num);
+
+ // Copy the TF_TensorData into a newly-created ndarray and return it.
+ // TODO(mrry): Perhaps investigate zero-copy approaches. This would involve
+ // creating an ndarray-like object that wraps the TF_Tensor buffer, and
+ // maps its destructor to TF_DeleteTensor.
+ Safe_PyObjectPtr safe_out_array =
+ tensorflow::make_safe(PyArray_Empty(ndims, dims.data(), descr, 0));
+ if (!safe_out_array) {
+ return errors::Internal("Could not allocate ndarray");
+ }
+ PyArrayObject* py_array =
+ reinterpret_cast<PyArrayObject*>(safe_out_array.get());
+ if (PyArray_NBYTES(py_array) != TF_TensorByteSize(tensor)) {
+ if (TF_TensorType(tensor) == TF_STRING) {
+ // Copy element by element.
+ auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get()));
+ for (tensorflow::int64 i = 0; i < nelems; ++i) {
+ auto s =
+ CopyStringToPyArrayElement(py_array, iter.get(), tensor, nelems, i);
+ if (!s.ok()) {
+ return s;
+ }
+ PyArray_ITER_NEXT(iter.get());
+ }
+ } else {
+ return errors::Internal("ndarray was ", PyArray_NBYTES(py_array),
+ " bytes but TF_Tensor was ",
+ TF_TensorByteSize(tensor), " bytes");
+ }
+ } else {
+ memcpy(py_array->data, TF_TensorData(tensor), PyArray_NBYTES(py_array));
+ }
+
+ // PyArray_Return turns rank 0 arrays into numpy scalars
+ *out_array = PyArray_Return(
+ reinterpret_cast<PyArrayObject*>(safe_out_array.release()));
+ return Status::OK();
+}
+
+tensorflow::Status TF_Status_to_Status(TF_Status* tf_status) {
+ TF_Code code = TF_GetCode(tf_status);
+ const string message(TF_Message(tf_status));
+
+ switch (code) {
+ case TF_OK:
+ return Status::OK();
+ case TF_CANCELLED:
+ return errors::Cancelled(message);
+ case TF_UNKNOWN:
+ return errors::Unknown(message);
+ case TF_INVALID_ARGUMENT:
+ return errors::InvalidArgument(message);
+ case TF_DEADLINE_EXCEEDED:
+ return errors::DeadlineExceeded(message);
+ case TF_NOT_FOUND:
+ return errors::NotFound(message);
+ case TF_ALREADY_EXISTS:
+ return errors::AlreadyExists(message);
+ case TF_PERMISSION_DENIED:
+ return errors::PermissionDenied(message);
+ case TF_UNAUTHENTICATED:
+ return errors::Unauthenticated(message);
+ case TF_RESOURCE_EXHAUSTED:
+ return errors::ResourceExhausted(message);
+ case TF_FAILED_PRECONDITION:
+ return errors::FailedPrecondition(message);
+ case TF_ABORTED:
+ return errors::Aborted(message);
+ case TF_OUT_OF_RANGE:
+ return errors::OutOfRange(message);
+ case TF_UNIMPLEMENTED:
+ return errors::Unimplemented(message);
+ case TF_INTERNAL:
+ return errors::Internal(message);
+ case TF_UNAVAILABLE:
+ return errors::Unavailable(message);
+ case TF_DATA_LOSS:
+ return errors::DataLoss(message);
+ default:
+ return errors::Internal("Got error with unknown code: ", code, " ",
+ message);
+ }
+}
+
+static bool numpy_imported = false;
+
+} // namespace
+
+Safe_PyObjectPtr make_safe(PyObject* o) {
+ return Safe_PyObjectPtr(o, Py_DECREF_wrapper);
+}
+
+// Wrapper for TF_Run that converts the arguments to appropriate types.
+// If *out_status is OK, the caller becomes the owner of the PyObjects
+// in *out_values.
+void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values) {
+ // 0. Ensure that numpy has been imported.
+ if (!numpy_imported) {
+ import_array();
+ numpy_imported = true;
+ }
+
+ // 1. Convert the feed inputs to the appropriate form for TF_Run.
+ NameVector input_names;
+ Safe_PyObjectVector
+ py_inputs_safe; // Used to decref the input arrays on failure.
+ Safe_TF_TensorVector inputs_safe; // Used to delete tensors on failure.
+ TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run.
+
+ for (const auto& name_and_array : inputs) {
+ py_inputs_safe.emplace_back(
+ make_safe(reinterpret_cast<PyObject*>(name_and_array.second)));
+ }
+
+ for (int i = 0; i < inputs.size(); ++i) {
+ input_names.push_back(inputs[i].first);
+ PyArrayObject* array = inputs[i].second;
+
+ // Convert numpy dtype to TensorFlow dtype.
+ TF_DataType dtype;
+ *out_status = PyArray_TYPE_to_TF_DataType(array, &dtype);
+ if (!out_status->ok()) {
+ return;
+ }
+
+ tensorflow::int64 nelems = 1;
+ gtl::InlinedVector<tensorflow::int64, 4> dims;
+ for (int i = 0; i < PyArray_NDIM(array); ++i) {
+ dims.push_back(PyArray_SHAPE(array)[i]);
+ nelems *= dims[i];
+ }
+
+ // Create a TF_Tensor based on the fed data. In the case of non-string data
+ // type, this steals a reference to array, which will be relinquished when
+ // the underlying buffer is deallocated. For string, a new temporary buffer
+ // is allocated into which the strings are encoded.
+ if (dtype != TF_STRING) {
+ // NOTE(mrry): We currently copy the numpy array into a new
+ // buffer to avoid possible issues on deallocation (such as
+ // having to acquire the Python Global Interpreter Lock).
+ // TODO(mrry): Investigate in what cases we can safely acquire
+ size_t size = PyArray_NBYTES(array);
+ // NOTE(mrry): 32 is the upper bound on current alignment
+ // requirements for tensorflow::Tensor. We hard code this here to
+ // avoid taking a dependency on Eigen in the client code.
+ void* data = tensorflow::cpu_allocator()->AllocateRaw(32, size);
+ std::memcpy(data, array->data, size);
+ inputs_safe.emplace_back(make_safe(
+ TF_NewTensor(dtype, dims.data(), dims.size(), data, size,
+ [](void* data, size_t len, void* arg) {
+ tensorflow::cpu_allocator()->DeallocateRaw(data);
+ },
+ nullptr)));
+ // The destruction of the numpy array will now be handled by the
+ // inputs_safe destructor.
+ py_inputs_safe[i].reset();
+ } else {
+ size_t size;
+ void* encoded;
+ Status s = EncodePyStringArray(array, nelems, &size, &encoded);
+ if (!s.ok()) {
+ *out_status = s;
+ return;
+ }
+ inputs_safe.emplace_back(
+ make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), encoded, size,
+ [](void* data, size_t len, void* arg) {
+ delete[] reinterpret_cast<char*>(data);
+ },
+ array)));
+ // The destruction of the numpy array will now be handled by the
+ // inputs_safe destructor.
+ py_inputs_safe[i].reset();
+ }
+ inputs_unsafe.push_back(inputs_safe.back().get());
+ }
+
+ // 2. Allocate a container for the output data.
+ TF_TensorVector outputs(output_names.size());
+
+ Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
+
+ // 3. Actually call TF_Run().
+ Py_BEGIN_ALLOW_THREADS;
+ TF_Run(session, input_names.data(), inputs_unsafe.data(), input_names.size(),
+ const_cast<const char**>(output_names.data()), outputs.data(),
+ output_names.size(), const_cast<const char**>(target_nodes.data()),
+ target_nodes.size(), status.get());
+ Py_END_ALLOW_THREADS;
+
+ // 4. The TensorFlow runtime has taken ownership of the fed tensors,
+ // so we release the safe pointers to them.
+ for (auto& input : inputs_safe) {
+ input.release();
+ }
+
+ if (TF_GetCode(status.get()) != TF_OK) {
+ *out_status = TF_Status_to_Status(status.get());
+ return;
+ }
+
+ // 5. We now own the fetched tensors, so set up a safe container to
+ // delete them when we exit this scope.
+ Safe_TF_TensorVector tf_outputs_safe;
+ for (const auto& output : outputs) {
+ tf_outputs_safe.emplace_back(make_safe(output));
+ }
+
+ // 6. Convert the fetched tensors into numpy ndarrays. Store them in a safe
+ // container so that we do not leak
+ Safe_PyObjectVector py_outputs_safe;
+ for (int i = 0; i < output_names.size(); ++i) {
+ PyObject* py_array;
+ *out_status = TF_Tensor_to_PyObject(outputs[i], &py_array);
+ if (!out_status->ok()) {
+ return;
+ }
+ py_outputs_safe.emplace_back(make_safe(py_array));
+ }
+
+ // 7. If we reach this point, we have successfully built a list of objects
+ // so we can release them from the safe container.
+ for (auto& output : py_outputs_safe) {
+ out_values->push_back(output.release());
+ }
+ *out_status = Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
new file mode 100644
index 0000000000..12a7527ed9
--- /dev/null
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -0,0 +1,56 @@
+#ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
+#define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor_c_api.h"
+
+namespace tensorflow {
+
+// Container types for the various arguments and temporary values used
+// in the wrapper.
+
+// A FeedVector is a vector of tensor name and numpy array pairs. The
+// name is a borrowed C string.
+typedef tensorflow::gtl::InlinedVector<std::pair<const char*, PyArrayObject*>,
+ 8> FeedVector;
+
+// A NameVector is a vector of tensor or operation names, as borrowed
+// C strings.
+typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector;
+
+// A PyObjectVector is a vector of borrowed pointers to PyObjects.
+typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector;
+
+// Safe containers for (an) owned PyObject(s). On destruction, the
+// reference count of the contained object will be decremented.
+inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); }
+typedef void (*Py_DECREF_wrapper_type)(PyObject*);
+typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr;
+typedef std::vector<Safe_PyObjectPtr> Safe_PyObjectVector;
+Safe_PyObjectPtr make_safe(PyObject* o);
+
+// Run the graph associated with the session starting with the
+// supplied inputs[]. Regardless of success of failure, inputs[] are
+// stolen by the implementation (i.e. the implementation will
+// eventually call Py_DECREF on each array input).
+//
+// On success, the tensors corresponding to output_names[0,noutputs-1]
+// are placed in out_values[], and these outputs[] become the property
+// of the caller (the caller must eventually call Py_DECREF on them).
+//
+// On failure, out_status contains a tensorflow::Status with an error
+// message.
+void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs,
+ const NameVector& output_names,
+ const NameVector& target_nodes, Status* out_status,
+ PyObjectVector* out_values);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_