diff options
-rw-r--r-- | tensorflow/python/eager/backprop.py | 2
-rw-r--r-- | tensorflow/python/eager/context.py | 15
-rw-r--r-- | tensorflow/python/eager/function.py | 144
-rw-r--r-- | tensorflow/python/eager/graph_callable.py | 18
-rw-r--r-- | tensorflow/python/eager/graph_callable_test.py | 1
-rw-r--r-- | tensorflow/python/framework/ops.py | 30
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 3
7 files changed, 70 insertions(+), 143 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index dc1142705a..0144f3b1e5 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -540,7 +540,7 @@ def _ensure_unique_tensor_objects(parameter_positions, args): if i in parameter_positions: tid = ops.tensor_id(t) if tid in s: - args[i] = gen_array_ops.identity(args[i]) + args[i] = args[i]._dup() # pylint: disable=protected-access else: s.add(tid) return args diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 415416cfae..92f4e15c05 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -288,21 +288,6 @@ class Context(object): self._initialize_handle_and_devices() return self._num_gpus - def add_function(self, fn): - """Add a function definition to the context. - - Once added, the function (identified by its name) can be executed like any - other operation. - - Args: - fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). - """ - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunction( - self._handle, # pylint: disable=protected-access - fn, - status) - def add_function_def(self, fdef): """Add a function definition to the context. 
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 092b36ff20..2f4b59e938 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -25,19 +25,15 @@ import threading import numpy as np -from tensorflow.core.framework import function_pb2 -from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager.graph_only_ops import graph_placeholder -from tensorflow.python.framework import c_api_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import errors +from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -51,41 +47,10 @@ _scoped_captures = threading.local() _scoped_captures.tensors = None -def make_function_def(name, graph, operations, inputs, outputs): - """Makes FunctionDef proto and defined function. 
- - Args: - name: the function name - graph: the graph from which to build the function - operations: the operations in the function body - inputs: tensors to be used as function arguments - outputs: tensors to be returned from the function - - Returns: - fdef: a FunctionDef protocol buffer for the function - fn: a wrapped TF_Function for the function - """ - with errors.raise_exception_on_not_ok_status() as status: - fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( - graph._c_graph, # pylint: disable=protected-access - compat.as_text(name), - False, - [o._c_op for o in operations], # pylint: disable=protected-access - [t._as_tf_output() for t in inputs], # pylint: disable=protected-access - [t._as_tf_output() for t in outputs], # pylint: disable=protected-access - [compat.as_text("%s" % i) for i in range(len(outputs))], - None, - compat.as_text(""), - status) - # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature, - # but also in general it's nice not to depend on it. - with c_api_util.tf_buffer() as buffer_: - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status) - proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) - fdef = function_pb2.FunctionDef() - fdef.ParseFromString(compat.as_bytes(proto_data)) - return fdef, fn +def make_function_def(graph, operations, inputs, outputs): + """Makes function def from the given graph with the operations.""" + return graph_to_function_def.graph_to_function_def( + graph, operations, inputs, outputs) @contextlib.contextmanager @@ -150,10 +115,6 @@ class CapturingGraph(ops.Graph): # for resource tensors. self._last_op_using_resource_tensor = {} - # TODO(apassos) remove once the C API is used by default. 
- def _use_c_api_hack(self): - return True - def clear_resource_control_flow_state(self): self._last_op_using_resource_tensor = {} @@ -246,20 +207,14 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) -# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction -# so it doesn't have the definition-generating logic and is just a container for -# an already-defined function. class _DefinedFunction(object): """Mocks the interface of tf _DefinedFunction.""" - def __init__(self, fdef, fn): + def __init__(self, fdef): self.definition = fdef self.name = fdef.signature.name - self.signature = fdef.signature self.grad_func_name = None self.python_grad_func = None - self._c_func = fn - self._grad_func = None def _map_sequence_obj_to_idx(sequence): @@ -295,7 +250,6 @@ class GraphModeFunction(object): input_placeholders, extra_inputs, fdef, - fn, graph, operations, func_outputs, @@ -309,7 +263,7 @@ class GraphModeFunction(object): self._graph = graph self._has_backprop = False self._func_name = fdef.signature.name - self._fdef = _DefinedFunction(fdef, fn) + self._fdef = _DefinedFunction(fdef) self._num_outputs = len(fdef.signature.output_arg) self._ops = operations self._func_outputs = func_outputs @@ -329,45 +283,38 @@ class GraphModeFunction(object): with self._graph.as_default(), context.graph_mode(): c = _CapturingContext() with c: - filtered_outputs = [x for x in self._returns if x is not None] + filtered_outputs = [ + x for x in self._returns if x is not None + ] self._out_grad_placeholders = [ - graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] + graph_placeholder(x.dtype, x.shape) for x in filtered_outputs + ] in_gradients = gradients_impl.gradients( filtered_outputs, self._input_placeholders, grad_ys=self._out_grad_placeholders) - shapes = tuple(x.shape for x in in_gradients if x is not None) + shapes = [x.shape for x in in_gradients if x is not None] captures = list(sorted(c.captured_tensors, key=lambda x: 
x.name)) - forward_name = _forward_name(self._func_name) - forward_function_def, forward_fn = make_function_def( - forward_name, self._graph, self._ops, self._input_placeholders, + forward_function_def = make_function_def( + self._graph, self._ops, self._input_placeholders, filtered_outputs + captures) - self._forward_fdef = _DefinedFunction(forward_function_def, forward_fn) - _register(forward_fn) - backward_outputs = tuple(x for x in in_gradients if x is not None) + self._forward_fdef = _DefinedFunction(forward_function_def) + _register_with_name(_forward_name(self._func_name), forward_function_def) + backward_outputs = [x for x in in_gradients if x is not None] all_inputs = self._out_grad_placeholders + captures - # Excluding input ops from the body as we do not intend to execute these - # operations when the function is executed. - all_ignored_ops = frozenset(x.op for x in all_inputs) - # Enforce a deterministic order of operations in the generated graph. This - # means rerunning the function-defining code will always define the same - # function, which is useful if we serialize this etc. 
- fdef_ops = tuple(x for x in sorted(c.known_ops, key=lambda x: x.name) - if x not in all_ignored_ops) - bname = _backward_name(self._func_name) - backward_function_def, backward_fn = make_function_def( - bname, self._graph, fdef_ops, + backward_function_def = make_function_def( + self._graph, [x.op for x in self._out_grad_placeholders + ] + list(sorted(c.known_ops, key=lambda x: x.name)), all_inputs, backward_outputs) - _register(backward_fn) + _register_with_name(_backward_name(self._func_name), backward_function_def) self._backward_function = GraphModeFunction( - all_inputs, [], backward_function_def, backward_fn, self._graph, - c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs), - shapes) + all_inputs, [], backward_function_def, self._graph, c.known_ops, + in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" all_args = args + self._extra_inputs - signature = self._forward_fdef.signature + signature = self._forward_fdef.definition.signature ctx = context.context() if ctx.in_graph_mode(): g = ops.get_default_graph() @@ -378,7 +325,7 @@ class GraphModeFunction(object): return ops.internal_convert_to_tensor(x, ctx=ctx) op = g.create_op( signature.name, [make_tensor(x) for x in all_args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), + [dtypes_module.DType(x.type) for x in signature.output_arg], op_def=signature, name="FunctionCall", compute_shapes=False) @@ -414,8 +361,11 @@ class GraphModeFunction(object): if v._trainable: # pylint: disable=protected-access tape.watch_variable(v) - tensor_inputs = [x for x in nest.flatten(args) - if isinstance(x, ops.Tensor)] + tensor_inputs = [ + x for x in nest.flatten(args) + if isinstance(x, ops.Tensor) + ] + if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): if not self._has_backprop: @@ -434,7 +384,7 @@ class GraphModeFunction(object): args = 
list(tensor_inputs) + self._extra_inputs op = g.create_op( signature.name, [ops.convert_to_tensor(x) for x in args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), + [dtypes_module.DType(x.type) for x in signature.output_arg], op_def=signature, name="FunctionCall", compute_shapes=False) @@ -519,32 +469,29 @@ def _defun_internal(name, func, args, kwds): extra_inputs = [] extra_placeholders = [] outputs_list = nest.flatten(func_outputs) - output_shapes = tuple(x.shape for x in outputs_list if x is not None) + output_shapes = [x.shape for x in outputs_list if x is not None] - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, ops.Tensor)] + flat_inputs = [ + x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor) + ] all_inputs = flat_inputs + list(extra_placeholders) - all_ignored_ops = frozenset(x.op for x in all_inputs) + func_def_outputs = [x for x in outputs_list if x is not None] - fname = _inference_name(name) - operations = tuple(x for x in tmp_graph.get_operations() - if x not in all_ignored_ops) - inference_function_def, fn = make_function_def( - fname, tmp_graph, operations, all_inputs, func_def_outputs) + inference_function_def = make_function_def( + tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? 
- _register(f._c_func) # pylint: disable=protected-access - _register(fn) + _register_with_name(f.name, f.definition) + _register_with_name(_inference_name(name), inference_function_def) return GraphModeFunction( all_inputs, extra_inputs, inference_function_def, - fn, tmp_graph, - operations, + tmp_graph.get_operations(), func_outputs, _map_sequence_obj_to_idx(func_def_outputs), output_shapes, @@ -570,9 +517,10 @@ def _cache_key(x): return x -def _register(fn): - """Registers the function `fn`.""" - context.context().add_function(fn) +def _register_with_name(name, fdef): + """Registers the function `fdef` with the name `name`.""" + fdef.signature.name = name + context.context().add_function_def(fdef) # TODO(apassos): better error messages for non-hashable arguments. diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 3da100d800..faf0ac88bc 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -318,9 +318,7 @@ def _graph_callable_internal(func, shape_and_dtypes): placeholder_inputs = flat_inputs+ list(extra_placeholders) func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] - initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access - initializer_function_def, initializer_fn = function.make_function_def( - initialization_name, + initializer_function_def = function.make_function_def( tmp_graph, initializing_operations, placeholder_inputs, @@ -329,13 +327,13 @@ def _graph_callable_internal(func, shape_and_dtypes): # Also, what about the gradient registry of these functions? Those need to be # addressed as well. 
for f in tmp_graph._functions.values(): # pylint: disable=protected-access - function._register(f._c_func) # pylint: disable=protected-access - function._register(initializer_fn) # pylint: disable=protected-access + function._register_with_name(f.name, f.definition) # pylint: disable=protected-access + function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access + initializer_function_def) initializer_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, initializer_function_def, - initializer_fn, tmp_graph, initializing_operations, func_outputs, @@ -344,20 +342,18 @@ def _graph_callable_internal(func, shape_and_dtypes): capture_func_def_outputs = [ x for x in captured_outlist if isinstance(x, tf_ops.Tensor)] - captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access - captured_function_def, capturing_fn = function.make_function_def( - captured_function_name, + captured_function_def = function.make_function_def( tmp_graph, capturing_operations, placeholder_inputs, capture_func_def_outputs) - function._register(capturing_fn) # pylint: disable=protected-access + function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access + captured_function_def) captured_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, captured_function_def, - capturing_fn, tmp_graph, capturing_operations, captured_outputs, diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py index b9e6ca2a93..548e16a909 100644 --- a/tensorflow/python/eager/graph_callable_test.py +++ b/tensorflow/python/eager/graph_callable_test.py @@ -152,6 +152,7 @@ class GraphCallableTest(test.TestCase): self.assertAllEqual(5, f(constant_op.constant(2))) def testNestedFunction(self): + # TensorFlow function (which is what would be used in TensorFlow graph # construction). 
@function.Defun(dtypes.int32, dtypes.int32) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 36daf59647..2217513966 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -599,6 +599,11 @@ class Tensor(_TensorLike): """ return _eval_using_default_session(self, feed_dict, self.graph, session) + def _dup(self): + ret = copy.copy(self) + ret._id = uid() # pylint: disable=protected-access + return ret + # TODO(agarwal): consider getting rid of this. class _EagerTensorBase(Tensor): @@ -724,6 +729,9 @@ class _EagerTensorBase(Tensor): return new_tensor # pylint: enable=protected-access + def _dup(self): + return self._copy(device_name=self.device) + @property def shape(self): return tensor_shape.TensorShape(self._shape_tuple()) @@ -1786,7 +1794,7 @@ class Operation(object): c_api.SetRequestedDevice( self._graph._c_graph, # pylint: disable=protected-access self._c_op, # pylint: disable=protected-access - compat.as_text(_device_string(device))) + _device_string(device)) else: self._node_def.device = _device_string(device) @@ -2075,7 +2083,7 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if self._c_op: + if _USE_C_API: buf = c_api.TF_NewBufferFromString( compat.as_bytes(attr_value.SerializeToString())) try: @@ -2644,16 +2652,11 @@ class Graph(object): # TODO(skyewm): fold as much of the above as possible into the C # implementation - if _USE_C_API or self._use_c_api_hack(): + if _USE_C_API: self._scoped_c_graph = c_api_util.ScopedTFGraph() else: self._scoped_c_graph = None - # TODO(apassos) remove once the C API is used by default. - def _use_c_api_hack(self): - """Temporary hack; can be overridden to force C API usage.""" - return False - def _convert_stack(self, stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. 
@@ -2982,14 +2985,9 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access if self._c_graph: - # Handle functions created without using the C API. TODO(apassos,skyewm) - # remove this when all functions are generated using the C API by default - # as this will be unnecessary. - if not function._c_func: - with errors.raise_exception_on_not_ok_status() as status: - serialized = function.definition.SerializeToString() - function._c_func = c_api.TF_FunctionImportFunctionDef( - serialized, status) + assert function._c_func, ( + "Cannot add function created without C API support to graph " + "created with C API support") with errors.raise_exception_on_not_ok_status() as status: gradient = function._grad_func._c_func if function._grad_func else None c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient, diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 82750e9e49..82b154164e 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -18,7 +18,6 @@ limitations under the License. %rename("%s") TFE_NewContext; %rename("%s") TFE_DeleteContext; %rename("%s") TFE_ContextListDevices; -%rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; @@ -150,7 +149,7 @@ limitations under the License. } $1 = &temp; $1->resize(PyInt_AsLong($input), nullptr); -} +} // Create new Status object. %typemap(in, numinputs=0) TF_Status *out_status { |