diff options
-rw-r--r-- | tensorflow/python/eager/backprop.py | 2 | ||||
-rw-r--r-- | tensorflow/python/eager/context.py | 15 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 144 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_callable.py | 18 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_callable_test.py | 1 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 30 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 3 |
7 files changed, 143 insertions, 70 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 0144f3b1e5..dc1142705a 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -540,7 +540,7 @@ def _ensure_unique_tensor_objects(parameter_positions, args): if i in parameter_positions: tid = ops.tensor_id(t) if tid in s: - args[i] = args[i]._dup() # pylint: disable=protected-access + args[i] = gen_array_ops.identity(args[i]) else: s.add(tid) return args diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 92f4e15c05..415416cfae 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -288,6 +288,21 @@ class Context(object): self._initialize_handle_and_devices() return self._num_gpus + def add_function(self, fn): + """Add a function definition to the context. + + Once added, the function (identified by its name) can be executed like any + other operation. + + Args: + fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). + """ + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextAddFunction( + self._handle, # pylint: disable=protected-access + fn, + status) + def add_function_def(self, fdef): """Add a function definition to the context. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 2f4b59e938..092b36ff20 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -25,15 +25,19 @@ import threading import numpy as np +from tensorflow.core.framework import function_pb2 +from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager.graph_only_ops import graph_placeholder +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -47,10 +51,41 @@ _scoped_captures = threading.local() _scoped_captures.tensors = None -def make_function_def(graph, operations, inputs, outputs): - """Makes function def from the given graph with the operations.""" - return graph_to_function_def.graph_to_function_def( - graph, operations, inputs, outputs) +def make_function_def(name, graph, operations, inputs, outputs): + """Makes FunctionDef proto and defined function. + + Args: + name: the function name + graph: the graph from which to build the function + operations: the operations in the function body + inputs: tensors to be used as function arguments + outputs: tensors to be returned from the function + + Returns: + fdef: a FunctionDef protocol buffer for the function + fn: a wrapped TF_Function for the function + """ + with errors.raise_exception_on_not_ok_status() as status: + fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( + graph._c_graph, # pylint: disable=protected-access + compat.as_text(name), + False, + [o._c_op for o in operations], # pylint: disable=protected-access + [t._as_tf_output() for t in inputs], # pylint: disable=protected-access + [t._as_tf_output() for t in outputs], # pylint: disable=protected-access + [compat.as_text("%s" % i) for i in range(len(outputs))], + None, + compat.as_text(""), + status) + # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature, + # but also in general it's nice not to depend on it. + with c_api_util.tf_buffer() as buffer_: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) + proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + fdef = function_pb2.FunctionDef() + fdef.ParseFromString(compat.as_bytes(proto_data)) + return fdef, fn @contextlib.contextmanager @@ -115,6 +150,10 @@ class CapturingGraph(ops.Graph): # for resource tensors. self._last_op_using_resource_tensor = {} + # TODO(apassos) remove once the C API is used by default. + def _use_c_api_hack(self): + return True + def clear_resource_control_flow_state(self): self._last_op_using_resource_tensor = {} @@ -207,14 +246,20 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) +# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction +# so it doesn't have the definition-generating logic and is just a container for +# an already-defined function. class _DefinedFunction(object): """Mocks the interface of tf _DefinedFunction.""" - def __init__(self, fdef): + def __init__(self, fdef, fn): self.definition = fdef self.name = fdef.signature.name + self.signature = fdef.signature self.grad_func_name = None self.python_grad_func = None + self._c_func = fn + self._grad_func = None def _map_sequence_obj_to_idx(sequence): @@ -250,6 +295,7 @@ class GraphModeFunction(object): input_placeholders, extra_inputs, fdef, + fn, graph, operations, func_outputs, @@ -263,7 +309,7 @@ class GraphModeFunction(object): self._graph = graph self._has_backprop = False self._func_name = fdef.signature.name - self._fdef = _DefinedFunction(fdef) + self._fdef = _DefinedFunction(fdef, fn) self._num_outputs = len(fdef.signature.output_arg) self._ops = operations self._func_outputs = func_outputs @@ -283,38 +329,45 @@ class GraphModeFunction(object): with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: - filtered_outputs = [ - x for x in self._returns if x is not None - ] + filtered_outputs = [x for x in self._returns if x is not None] self._out_grad_placeholders = [ - graph_placeholder(x.dtype, x.shape) for x in filtered_outputs - ] + graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] in_gradients = gradients_impl.gradients( filtered_outputs, self._input_placeholders, grad_ys=self._out_grad_placeholders) - shapes = [x.shape for x in in_gradients if x is not None] + shapes = tuple(x.shape for x in in_gradients if x is not None) captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) - forward_function_def = make_function_def( - self._graph, self._ops, self._input_placeholders, + forward_name = _forward_name(self._func_name) + forward_function_def, forward_fn = make_function_def( + forward_name, self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) - self._forward_fdef = _DefinedFunction(forward_function_def) - _register_with_name(_forward_name(self._func_name), forward_function_def) - backward_outputs = [x for x in in_gradients if x is not None] + self._forward_fdef = _DefinedFunction(forward_function_def, forward_fn) + _register(forward_fn) + backward_outputs = tuple(x for x in in_gradients if x is not None) all_inputs = self._out_grad_placeholders + captures - backward_function_def = make_function_def( - self._graph, [x.op for x in self._out_grad_placeholders - ] + list(sorted(c.known_ops, key=lambda x: x.name)), + # Excluding input ops from the body as we do not intend to execute these + # operations when the function is executed. + all_ignored_ops = frozenset(x.op for x in all_inputs) + # Enforce a deterministic order of operations in the generated graph. This + # means rerunning the function-defining code will always define the same + # function, which is useful if we serialize this etc. + fdef_ops = tuple(x for x in sorted(c.known_ops, key=lambda x: x.name) + if x not in all_ignored_ops) + bname = _backward_name(self._func_name) + backward_function_def, backward_fn = make_function_def( + bname, self._graph, fdef_ops, all_inputs, backward_outputs) - _register_with_name(_backward_name(self._func_name), backward_function_def) + _register(backward_fn) self._backward_function = GraphModeFunction( - all_inputs, [], backward_function_def, self._graph, c.known_ops, - in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes) + all_inputs, [], backward_function_def, backward_fn, self._graph, + c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs), + shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" all_args = args + self._extra_inputs - signature = self._forward_fdef.definition.signature + signature = self._forward_fdef.signature ctx = context.context() if ctx.in_graph_mode(): g = ops.get_default_graph() @@ -325,7 +378,7 @@ class GraphModeFunction(object): return ops.internal_convert_to_tensor(x, ctx=ctx) op = g.create_op( signature.name, [make_tensor(x) for x in all_args], - [dtypes_module.DType(x.type) for x in signature.output_arg], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), op_def=signature, name="FunctionCall", compute_shapes=False) @@ -361,11 +414,8 @@ class GraphModeFunction(object): if v._trainable: # pylint: disable=protected-access tape.watch_variable(v) - tensor_inputs = [ - x for x in nest.flatten(args) - if isinstance(x, ops.Tensor) - ] - + tensor_inputs = [x for x in nest.flatten(args) + if isinstance(x, ops.Tensor)] if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): if not self._has_backprop: @@ -384,7 +434,7 @@ class GraphModeFunction(object): args = list(tensor_inputs) + self._extra_inputs op = g.create_op( signature.name, [ops.convert_to_tensor(x) for x in args], - [dtypes_module.DType(x.type) for x in signature.output_arg], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), op_def=signature, name="FunctionCall", compute_shapes=False) @@ -469,29 +519,32 @@ def _defun_internal(name, func, args, kwds): extra_inputs = [] extra_placeholders = [] outputs_list = nest.flatten(func_outputs) - output_shapes = [x.shape for x in outputs_list if x is not None] + output_shapes = tuple(x.shape for x in outputs_list if x is not None) - flat_inputs = [ - x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor) - ] + flat_inputs = [x for x in nest.flatten(func_inputs) + if isinstance(x, ops.Tensor)] all_inputs = flat_inputs + list(extra_placeholders) - + all_ignored_ops = frozenset(x.op for x in all_inputs) func_def_outputs = [x for x in outputs_list if x is not None] - inference_function_def = make_function_def( - tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) + fname = _inference_name(name) + operations = tuple(x for x in tmp_graph.get_operations() + if x not in all_ignored_ops) + inference_function_def, fn = make_function_def( + fname, tmp_graph, operations, all_inputs, func_def_outputs) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? - _register_with_name(f.name, f.definition) - _register_with_name(_inference_name(name), inference_function_def) + _register(f._c_func) # pylint: disable=protected-access + _register(fn) return GraphModeFunction( all_inputs, extra_inputs, inference_function_def, + fn, tmp_graph, - tmp_graph.get_operations(), + operations, func_outputs, _map_sequence_obj_to_idx(func_def_outputs), output_shapes, @@ -517,10 +570,9 @@ def _cache_key(x): return x -def _register_with_name(name, fdef): - """Registers the function `fdef` with the name `name`.""" - fdef.signature.name = name - context.context().add_function_def(fdef) +def _register(fn): + """Registers the function `fn`.""" + context.context().add_function(fn) # TODO(apassos): better error messages for non-hashable arguments. diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index faf0ac88bc..3da100d800 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -318,7 +318,9 @@ def _graph_callable_internal(func, shape_and_dtypes): placeholder_inputs = flat_inputs+ list(extra_placeholders) func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] - initializer_function_def = function.make_function_def( + initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access + initializer_function_def, initializer_fn = function.make_function_def( + initialization_name, tmp_graph, initializing_operations, placeholder_inputs, @@ -327,13 +329,13 @@ def _graph_callable_internal(func, shape_and_dtypes): # Also, what about the gradient registry of these functions? Those need to be # addressed as well. for f in tmp_graph._functions.values(): # pylint: disable=protected-access - function._register_with_name(f.name, f.definition) # pylint: disable=protected-access - function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access - initializer_function_def) + function._register(f._c_func) # pylint: disable=protected-access + function._register(initializer_fn) # pylint: disable=protected-access initializer_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, initializer_function_def, + initializer_fn, tmp_graph, initializing_operations, func_outputs, @@ -342,18 +344,20 @@ def _graph_callable_internal(func, shape_and_dtypes): capture_func_def_outputs = [ x for x in captured_outlist if isinstance(x, tf_ops.Tensor)] - captured_function_def = function.make_function_def( + captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access + captured_function_def, capturing_fn = function.make_function_def( + captured_function_name, tmp_graph, capturing_operations, placeholder_inputs, capture_func_def_outputs) - function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access - captured_function_def) + function._register(capturing_fn) # pylint: disable=protected-access captured_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, captured_function_def, + capturing_fn, tmp_graph, capturing_operations, captured_outputs, diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py index 548e16a909..b9e6ca2a93 100644 --- a/tensorflow/python/eager/graph_callable_test.py +++ b/tensorflow/python/eager/graph_callable_test.py @@ -152,7 +152,6 @@ class GraphCallableTest(test.TestCase): self.assertAllEqual(5, f(constant_op.constant(2))) def testNestedFunction(self): - # TensorFlow function (which is what would be used in TensorFlow graph # construction). @function.Defun(dtypes.int32, dtypes.int32) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 2217513966..36daf59647 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -599,11 +599,6 @@ class Tensor(_TensorLike): """ return _eval_using_default_session(self, feed_dict, self.graph, session) - def _dup(self): - ret = copy.copy(self) - ret._id = uid() # pylint: disable=protected-access - return ret - # TODO(agarwal): consider getting rid of this. class _EagerTensorBase(Tensor): @@ -729,9 +724,6 @@ class _EagerTensorBase(Tensor): return new_tensor # pylint: enable=protected-access - def _dup(self): - return self._copy(device_name=self.device) - @property def shape(self): return tensor_shape.TensorShape(self._shape_tuple()) @@ -1794,7 +1786,7 @@ class Operation(object): c_api.SetRequestedDevice( self._graph._c_graph, # pylint: disable=protected-access self._c_op, # pylint: disable=protected-access - _device_string(device)) + compat.as_text(_device_string(device))) else: self._node_def.device = _device_string(device) @@ -2083,7 +2075,7 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if _USE_C_API: + if self._c_op: buf = c_api.TF_NewBufferFromString( compat.as_bytes(attr_value.SerializeToString())) try: @@ -2652,11 +2644,16 @@ class Graph(object): # TODO(skyewm): fold as much of the above as possible into the C # implementation - if _USE_C_API: + if _USE_C_API or self._use_c_api_hack(): self._scoped_c_graph = c_api_util.ScopedTFGraph() else: self._scoped_c_graph = None + # TODO(apassos) remove once the C API is used by default. + def _use_c_api_hack(self): + """Temporary hack; can be overridden to force C API usage.""" + return False + def _convert_stack(self, stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. @@ -2985,9 +2982,14 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access if self._c_graph: - assert function._c_func, ( - "Cannot add function created without C API support to graph " - "created with C API support") + # Handle functions created without using the C API. TODO(apassos,skyewm) + # remove this when all functions are generated using the C API by default + # as this will be unnecessary. + if not function._c_func: + with errors.raise_exception_on_not_ok_status() as status: + serialized = function.definition.SerializeToString() + function._c_func = c_api.TF_FunctionImportFunctionDef( + serialized, status) with errors.raise_exception_on_not_ok_status() as status: gradient = function._grad_func._c_func if function._grad_func else None c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient, diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 82b154164e..82750e9e49 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -18,6 +18,7 @@ limitations under the License. %rename("%s") TFE_NewContext; %rename("%s") TFE_DeleteContext; %rename("%s") TFE_ContextListDevices; +%rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; @@ -149,7 +150,7 @@ limitations under the License. } $1 = &temp; $1->resize(PyInt_AsLong($input), nullptr); -} +} // Create new Status object. %typemap(in, numinputs=0) TF_Status *out_status { |