diff options
author | 2017-05-10 11:39:52 -0700 | |
---|---|---|
committer | 2017-05-10 17:41:08 -0700 | |
commit | b9cdcbe3a019f09eee7e7c9c9039e4533931e1e7 (patch) | |
tree | 9c6c1689c9a3eab2cecc8aaa54790c23c5a3345c | |
parent | 9472ba05620332fb274e643e08228b5c530e3e52 (diff) |
Create a TF_Graph alongside the Python graph.
This is a first step towards porting the Python API to use the C API. As the Python Graph and Operations are constructed, an analogous TF_Graph and TF_Operations are created via SWIG. Currently nothing is done with the TF_Graph; a next step will be switching to the new TF_Session API, which runs a TF_Graph directly (instead of a GraphDef).
This new functionality is disabled by default and can be manually enabled by setting the _USE_C_API global in ops.py. For this patch I only enabled it for a single test file. I tried enabling it for all TF Python tests and manually disabling it for unsupported tests, but there were too many failing tests (although most tests passed). See ops.py for a TODO list of unsupported functionality.
I benchmarked building an Inception model, and building the TF_Graph adds a 20% overhead to the total graph construction time. Note that this patch does not remove any existing Python functionality; a further next step will be to recover this time by removing redundant Python code. There is no measurable overhead with the new functionality disabled.
PiperOrigin-RevId: 155655064
-rw-r--r-- | tensorflow/python/client/tf_session.i | 69 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 135 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops_test.py | 2 |
3 files changed, 182 insertions, 24 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 7c6f1cdd5e..902f02a256 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -156,35 +156,56 @@ tensorflow::ImportNumpy(); reinterpret_cast<const char*>($1.data), $1.length); } -// Include the functions from c_api.h, except TF_Run. -%ignoreall -%unignore TF_Code; -%unignore TF_Status; -%unignore TF_Buffer; -%unignore TF_NewBuffer; -%unignore TF_NewBufferFromString; -%unignore TF_DeleteBuffer; -%unignore TF_GetBuffer; -%unignore TF_NewStatus; -%unignore TF_DeleteStatus; -%unignore TF_GetCode; -%unignore TF_Message; -%unignore TF_SessionOptions; +%inline %{ +// Helper function to convert a Python list of Tensors to a C++ vector of +// TF_Outputs. +// +// Caller should have already checked that `py_tensor_list` is a list (this +// isn't done in this function to allow for function-specific error messages) +void PyTensorListToVector(PyObject* py_tensor_list, + std::vector<TF_Output>* vec) { + size_t size = PyList_Size(py_tensor_list); + for (int i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem(py_tensor_list, i); + TF_Output* input_ptr; + SWIG_ConvertPtr(item, reinterpret_cast<void**>(&input_ptr), + SWIGTYPE_p_TF_Output, 0); + vec->push_back(*input_ptr); + } +} +%} + +// Converts input Python list of wrapped TF_Outputs into a single array +%typemap(in) (const TF_Output* inputs, int num_inputs) + (std::vector<TF_Output> inputs) { + if (!PyList_Check($input)) { + SWIG_exception_fail( + SWIG_TypeError, "$symname: expected Python list of wrapped TF_Outputs"); + } + PyTensorListToVector($input, &inputs); + $1 = inputs.data(); + $2 = inputs.size(); +} + +// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams, +// skip for now +%ignore TF_WhileParams; +%ignore TF_NewWhile; +%ignore TF_FinishWhile; +%ignore TF_AbortWhile; + +// These are defined below, avoid duplicate definitions +%ignore TF_Run; +%ignore 
TF_PRun; +%ignore TF_PRunSetup; + %rename("_TF_SetTarget") TF_SetTarget; %rename("_TF_SetConfig") TF_SetConfig; %rename("_TF_NewSessionOptions") TF_NewSessionOptions; -%unignore TF_DeleteSessionOptions; -%unignore TF_NewDeprecatedSession; -%unignore TF_CloseDeprecatedSession; -%unignore TF_DeleteDeprecatedSession; -%unignore TF_ExtendGraph; -%unignore TF_NewLibrary; -%unignore TF_LoadLibrary; -%unignore TF_DeleteLibraryHandle; -%unignore TF_GetOpList; + %include "tensorflow/c/c_api.h" -%ignoreall +%ignoreall %insert("python") %{ def TF_NewSessionOptions(target=None, config=None): # NOTE: target and config are validated in the session constructor. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0b04904ec2..2581ac6a3c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -34,8 +34,10 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 +from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape @@ -46,6 +48,24 @@ from tensorflow.python.util import decorator_utils from tensorflow.python.util import tf_contextlib +# Temporary global switch determining if we should enable the work-in-progress +# calls to the C API. Currently disabled by default but can be manually enabled +# e.g. in tests. This will be removed once all functionality is supported and +# there's no performance penalty with it enabled. +# +# TODO(skyewm) before we can remove this: +# - functions +# - import_graph_def() incrementally adds inputs to ops (i.e. 
creates an +# Operation and then calls _add_input()). The current code requires that all +# inputs be specified when creating the Operation (since we call +# TF_FinishOperation()). +# - ops_test.py (and others?) create unregistered op types +# - while loop +# - performance (e.g. delete/refactor redundant Python functionality, switch to +# new session API) +_USE_C_API = False + + def _override_helper(clazz_object, operator, func): """Overrides (string) operator on Tensors to call func. @@ -467,6 +487,13 @@ class Tensor(_TensorLike): else: return "%s:%d" % (self._op.name, self._value_index) + def _as_tf_output(self): + assert self.op._c_op # pylint: disable=protected-access + tf_output = c_api.TF_Output() + tf_output.oper = self.op._c_op # pylint: disable=protected-access + tf_output.index = self.value_index + return tf_output + def __str__(self): return "Tensor(\"%s\"%s%s%s)" % ( self.name, @@ -1252,6 +1279,98 @@ class Operation(object): self._id_value = self._graph._next_id() # pylint: disable=protected-access self._recompute_node_def() + if _USE_C_API: + assert self._graph._c_graph, ( # pylint: disable=protected-access + "_USE_C_API set to False when creating Graph, you may need to " + "manually set 'ops._USE_C_API = True' before creating the Graph") + if self._op_def: + # TODO(skyewm): op_def_library.apply_op() flattens the incoming + # inputs. Refactor so we don't have to do this here. + grouped_inputs = self._reconstruct_sequence_inputs( + self._op_def, self._inputs, self._node_def.attr) + else: + # If no OpDef is specified, assume all inputs are scalar. + grouped_inputs = self._inputs + + self._c_op = self._create_c_op(self._graph, self._node_def, + grouped_inputs) + else: + self._c_op = None + + def _create_c_op(self, graph, node_def, inputs): + """Creates a TF_Operation. + + Arguments: + graph: a `Graph`. + node_def: `node_def_pb2.NodeDef` for the operation to create. 
+ inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of + `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", + "list(int64)"). The length of the list should be equal to the number of + inputs specified by this operation's op def. + + Returns: + A wrapped TF_Operation*. + """ + # pylint: disable=protected-access + op_desc = c_api.TF_NewOperation(graph._c_graph.g, + compat.as_str(node_def.op), + compat.as_str(node_def.name)) + + for op_input in inputs: + if isinstance(op_input, (list, tuple)): + c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input]) + else: + c_api.TF_AddInput(op_desc, op_input._as_tf_output()) + # pylint: enable=protected-access + + for name, attr_value in node_def.attr.items(): + serialized = attr_value.SerializeToString() + # TODO(skyewm): this creates and deletes a new TF_Status for every attr. + # It might be worth creating a convenient way to re-use the same status. + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized, + status) + + with errors.raise_exception_on_not_ok_status() as status: + c_op = c_api.TF_FinishOperation(op_desc, status) + + return c_op + + def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): + """Regroups a flat list of input tensors into scalar and sequence inputs. + + Arguments: + op_def: The `op_def_pb2.OpDef` (for knowing the input types) + inputs: a list of input `Tensor`s to the op. + attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define + how long each sequence is) + + Returns: + A list of `Tensor`s (corresponding to scalar inputs) and lists of + `Tensor`s (corresponding to sequence inputs). 
+ """ + grouped_inputs = [] + i = 0 + for input_arg in op_def.input_arg: + if input_arg.number_attr: + input_len = attrs[input_arg.number_attr].i + is_sequence = True + elif input_arg.type_list_attr: + input_len = len(attrs[input_arg.type_list_attr].list.type) + is_sequence = True + else: + input_len = 1 + is_sequence = False + + if is_sequence: + grouped_inputs.append(inputs[i:i + input_len]) + else: + grouped_inputs.append(inputs[i]) + i += input_len + + assert i == len(inputs) + return grouped_inputs + def colocation_groups(self): """Returns the list of colocation groups of the op.""" default_colocation_group = [compat.as_bytes("loc:@%s" % @@ -1911,6 +2030,15 @@ def _name_from_scope_name(name): return name[:-1] if name[-1] == "/" else name +class _ScopedTF_Graph(object): + + def __init__(self): + self.g = c_api.TF_NewGraph() + + def __del__(self): + c_api.TF_DeleteGraph(self.g) + + class Graph(object): """A TensorFlow computation, represented as a dataflow graph. @@ -2024,6 +2152,13 @@ class Graph(object): self._container = "" self._registered_ops = op_def_registry.get_registered_ops() + # TODO(skyewm): fold as much of the above as possible into the C + # implementation + if _USE_C_API: + self._c_graph = _ScopedTF_Graph() + else: + self._c_graph = None + def _check_not_finalized(self): """Check if the graph is finalized. diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 7dbc8efe16..ed30c4bcfa 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -28,6 +28,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +ops._USE_C_API = True + exp = np.exp log = np.log |