about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Skye Wanderman-Milne <skyewm@google.com> 2017-05-10 11:39:52 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-05-10 17:41:08 -0700
commit: b9cdcbe3a019f09eee7e7c9c9039e4533931e1e7 (patch)
tree: 9c6c1689c9a3eab2cecc8aaa54790c23c5a3345c
parent: 9472ba05620332fb274e643e08228b5c530e3e52 (diff)
Create a TF_Graph alongside the Python graph.

This is a first step towards porting the Python API to use the C API. As the Python Graph and Operations are constructed, an analogous TF_Graph and TF_Operations are created via SWIG. Currently nothing is done with the TF_Graph; a next step will be switching to the new TF_Session API which runs a TF_Graph directly (instead of a GraphDef).

This new functionality is disabled by default and can be manually enabled by setting the _USE_C_API global in ops.py. For this patch I only enabled it for a single test file. I tried enabling it for all TF Python tests and manually disabling it for unsupported tests, but there were too many failing tests (although most tests passed). See ops.py for a TODO list of unsupported functionality.

I benchmarked building an Inception model, and building the TF_Graph incurs a 20% overhead to the total graph construction time. Note that this patch does not remove any existing Python functionality; another next step will be recovering this time by removing redundant Python code. There is no measurable overhead with the new functionality disabled.

PiperOrigin-RevId: 155655064
-rw-r--r-- tensorflow/python/client/tf_session.i 69
-rw-r--r-- tensorflow/python/framework/ops.py 135
-rw-r--r-- tensorflow/python/ops/math_ops_test.py 2
3 files changed, 182 insertions, 24 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 7c6f1cdd5e..902f02a256 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -156,35 +156,56 @@ tensorflow::ImportNumpy();
reinterpret_cast<const char*>($1.data), $1.length);
}
-// Include the functions from c_api.h, except TF_Run.
-%ignoreall
-%unignore TF_Code;
-%unignore TF_Status;
-%unignore TF_Buffer;
-%unignore TF_NewBuffer;
-%unignore TF_NewBufferFromString;
-%unignore TF_DeleteBuffer;
-%unignore TF_GetBuffer;
-%unignore TF_NewStatus;
-%unignore TF_DeleteStatus;
-%unignore TF_GetCode;
-%unignore TF_Message;
-%unignore TF_SessionOptions;
+%inline %{
+// Helper function to convert a Python list of wrapped TF_Outputs to a C++
+// vector of TF_Outputs. One element is appended to `vec` per list item.
+//
+// Caller should have already checked that `py_tensor_list` is a list (this
+// isn't done in this function to allow for function-specific error messages).
+//
+// If a list item is not a wrapped TF_Output, a Python TypeError is set and the
+// function returns early; callers must check PyErr_Occurred() afterwards.
+void PyTensorListToVector(PyObject* py_tensor_list,
+                          std::vector<TF_Output>* vec) {
+  // PyList_Size() returns a signed Py_ssize_t; use the same type for the index
+  // so the loop condition is not a signed/unsigned comparison.
+  const Py_ssize_t size = PyList_Size(py_tensor_list);
+  for (Py_ssize_t i = 0; i < size; ++i) {
+    PyObject* item = PyList_GetItem(py_tensor_list, i);
+    TF_Output* input_ptr = nullptr;
+    // Only dereference `input_ptr` if the conversion actually succeeded;
+    // otherwise it does not point at a TF_Output.
+    if (!SWIG_IsOK(SWIG_ConvertPtr(item, reinterpret_cast<void**>(&input_ptr),
+                                   SWIGTYPE_p_TF_Output, 0)) ||
+        input_ptr == nullptr) {
+      PyErr_SetString(PyExc_TypeError,
+                      "expected Python list of wrapped TF_Outputs");
+      return;
+    }
+    vec->push_back(*input_ptr);
+  }
+}
+%}
+
+// Converts input Python list of wrapped TF_Outputs into a single array
+%typemap(in) (const TF_Output* inputs, int num_inputs)
+    (std::vector<TF_Output> inputs) {
+  if (!PyList_Check($input)) {
+    SWIG_exception_fail(
+        SWIG_TypeError, "$symname: expected Python list of wrapped TF_Outputs");
+  }
+  PyTensorListToVector($input, &inputs);
+  // The helper reports a bad list item by setting a Python error.
+  if (PyErr_Occurred()) SWIG_fail;
+  $1 = inputs.data();
+  // `num_inputs` is an int; make the narrowing from size_t explicit.
+  $2 = static_cast<int>(inputs.size());
+}
+
+// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
+// skip for now
+%ignore TF_WhileParams;
+%ignore TF_NewWhile;
+%ignore TF_FinishWhile;
+%ignore TF_AbortWhile;
+
+// These are defined below, avoid duplicate definitions
+%ignore TF_Run;
+%ignore TF_PRun;
+%ignore TF_PRunSetup;
+
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
-%unignore TF_DeleteSessionOptions;
-%unignore TF_NewDeprecatedSession;
-%unignore TF_CloseDeprecatedSession;
-%unignore TF_DeleteDeprecatedSession;
-%unignore TF_ExtendGraph;
-%unignore TF_NewLibrary;
-%unignore TF_LoadLibrary;
-%unignore TF_DeleteLibraryHandle;
-%unignore TF_GetOpList;
+
%include "tensorflow/c/c_api.h"
-%ignoreall
+%ignoreall
%insert("python") %{
def TF_NewSessionOptions(target=None, config=None):
# NOTE: target and config are validated in the session constructor.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 0b04904ec2..2581ac6a3c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -34,8 +34,10 @@ from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
@@ -46,6 +48,24 @@ from tensorflow.python.util import decorator_utils
from tensorflow.python.util import tf_contextlib
+# Temporary global switch determining if we should enable the work-in-progress
+# calls to the C API. Currently disabled by default but can be manually enabled
+# e.g. in tests. This will be removed once all functionality is supported and
+# there's no performance penalty with it enabled.
+#
+# TODO(skyewm) before we can remove this:
+# - functions
+# - import_graph_def() incrementally adds inputs to ops (i.e. creates an
+# Operation and then calls _add_input()). The current code requires that all
+# inputs be specified when creating the Operation (since we call
+# TF_FinishOperation()).
+# - ops_test.py (and others?) create unregistered op types
+# - while loop
+# - performance (e.g. delete/refactor redundant Python functionality, switch to
+# new session API)
+_USE_C_API = False
+
+
def _override_helper(clazz_object, operator, func):
"""Overrides (string) operator on Tensors to call func.
@@ -467,6 +487,13 @@ class Tensor(_TensorLike):
else:
return "%s:%d" % (self._op.name, self._value_index)
+ def _as_tf_output(self):
+ assert self.op._c_op # pylint: disable=protected-access
+ tf_output = c_api.TF_Output()
+ tf_output.oper = self.op._c_op # pylint: disable=protected-access
+ tf_output.index = self.value_index
+ return tf_output
+
def __str__(self):
return "Tensor(\"%s\"%s%s%s)" % (
self.name,
@@ -1252,6 +1279,98 @@ class Operation(object):
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._recompute_node_def()
+ if _USE_C_API:
+ assert self._graph._c_graph, ( # pylint: disable=protected-access
+ "_USE_C_API set to False when creating Graph, you may need to "
+ "manually set 'ops._USE_C_API = True' before creating the Graph")
+ if self._op_def:
+ # TODO(skyewm): op_def_library.apply_op() flattens the incoming
+ # inputs. Refactor so we don't have to do this here.
+ grouped_inputs = self._reconstruct_sequence_inputs(
+ self._op_def, self._inputs, self._node_def.attr)
+ else:
+ # If no OpDef is specified, assume all inputs are scalar.
+ grouped_inputs = self._inputs
+
+ self._c_op = self._create_c_op(self._graph, self._node_def,
+ grouped_inputs)
+ else:
+ self._c_op = None
+
+ def _create_c_op(self, graph, node_def, inputs):
+ """Creates a TF_Operation.
+
+ Arguments:
+ graph: a `Graph`.
+ node_def: `node_def_pb2.NodeDef` for the operation to create.
+ inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
+ `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
+ "list(int64)"). The length of the list should be equal to the number of
+ inputs specified by this operation's op def.
+
+ Returns:
+ A wrapped TF_Operation*.
+ """
+ # pylint: disable=protected-access
+ op_desc = c_api.TF_NewOperation(graph._c_graph.g,
+ compat.as_str(node_def.op),
+ compat.as_str(node_def.name))
+
+ for op_input in inputs:
+ if isinstance(op_input, (list, tuple)):
+ c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
+ else:
+ c_api.TF_AddInput(op_desc, op_input._as_tf_output())
+ # pylint: enable=protected-access
+
+ for name, attr_value in node_def.attr.items():
+ serialized = attr_value.SerializeToString()
+ # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
+ # It might be worth creating a convenient way to re-use the same status.
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized,
+ status)
+
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_op = c_api.TF_FinishOperation(op_desc, status)
+
+ return c_op
+
+ def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
+ """Regroups a flat list of input tensors into scalar and sequence inputs.
+
+ Arguments:
+ op_def: The `op_def_pb2.OpDef` (for knowing the input types)
+ inputs: a list of input `Tensor`s to the op.
+ attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
+ how long each sequence is)
+
+ Returns:
+ A list of `Tensor`s (corresponding to scalar inputs) and lists of
+ `Tensor`s (corresponding to sequence inputs).
+ """
+ grouped_inputs = []
+ i = 0
+ for input_arg in op_def.input_arg:
+ if input_arg.number_attr:
+ input_len = attrs[input_arg.number_attr].i
+ is_sequence = True
+ elif input_arg.type_list_attr:
+ input_len = len(attrs[input_arg.type_list_attr].list.type)
+ is_sequence = True
+ else:
+ input_len = 1
+ is_sequence = False
+
+ if is_sequence:
+ grouped_inputs.append(inputs[i:i + input_len])
+ else:
+ grouped_inputs.append(inputs[i])
+ i += input_len
+
+ assert i == len(inputs)
+ return grouped_inputs
+
def colocation_groups(self):
"""Returns the list of colocation groups of the op."""
default_colocation_group = [compat.as_bytes("loc:@%s" %
@@ -1911,6 +2030,15 @@ def _name_from_scope_name(name):
return name[:-1] if name[-1] == "/" else name
+class _ScopedTF_Graph(object):
+
+ def __init__(self):
+ self.g = c_api.TF_NewGraph()
+
+ def __del__(self):
+ c_api.TF_DeleteGraph(self.g)
+
+
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2024,6 +2152,13 @@ class Graph(object):
self._container = ""
self._registered_ops = op_def_registry.get_registered_ops()
+ # TODO(skyewm): fold as much of the above as possible into the C
+ # implementation
+ if _USE_C_API:
+ self._c_graph = _ScopedTF_Graph()
+ else:
+ self._c_graph = None
+
def _check_not_finalized(self):
"""Check if the graph is finalized.
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 7dbc8efe16..ed30c4bcfa 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -28,6 +28,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+ops._USE_C_API = True
+
exp = np.exp
log = np.log