aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/backprop.py2
-rw-r--r--tensorflow/python/eager/context.py15
-rw-r--r--tensorflow/python/eager/function.py144
-rw-r--r--tensorflow/python/eager/graph_callable.py18
-rw-r--r--tensorflow/python/eager/graph_callable_test.py1
-rw-r--r--tensorflow/python/framework/ops.py30
-rw-r--r--tensorflow/python/pywrap_tfe.i3
7 files changed, 143 insertions, 70 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 0144f3b1e5..dc1142705a 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -540,7 +540,7 @@ def _ensure_unique_tensor_objects(parameter_positions, args):
if i in parameter_positions:
tid = ops.tensor_id(t)
if tid in s:
- args[i] = args[i]._dup() # pylint: disable=protected-access
+ args[i] = gen_array_ops.identity(args[i])
else:
s.add(tid)
return args
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 92f4e15c05..415416cfae 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -288,6 +288,21 @@ class Context(object):
self._initialize_handle_and_devices()
return self._num_gpus
+ def add_function(self, fn):
+ """Add a function definition to the context.
+
+ Once added, the function (identified by its name) can be executed like any
+ other operation.
+
+ Args:
+ fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
+ """
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TFE_ContextAddFunction(
+ self._handle, # pylint: disable=protected-access
+ fn,
+ status)
+
def add_function_def(self, fdef):
"""Add a function definition to the context.
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 2f4b59e938..092b36ff20 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -25,15 +25,19 @@ import threading
import numpy as np
+from tensorflow.core.framework import function_pb2
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
+from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
-from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -47,10 +51,41 @@ _scoped_captures = threading.local()
_scoped_captures.tensors = None
-def make_function_def(graph, operations, inputs, outputs):
- """Makes function def from the given graph with the operations."""
- return graph_to_function_def.graph_to_function_def(
- graph, operations, inputs, outputs)
+def make_function_def(name, graph, operations, inputs, outputs):
+ """Makes FunctionDef proto and defined function.
+
+ Args:
+ name: the function name
+ graph: the graph from which to build the function
+ operations: the operations in the function body
+ inputs: tensors to be used as function arguments
+ outputs: tensors to be returned from the function
+
+ Returns:
+ fdef: a FunctionDef protocol buffer for the function
+ fn: a wrapped TF_Function for the function
+ """
+ with errors.raise_exception_on_not_ok_status() as status:
+ fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
+ graph._c_graph, # pylint: disable=protected-access
+ compat.as_text(name),
+ False,
+ [o._c_op for o in operations], # pylint: disable=protected-access
+ [t._as_tf_output() for t in inputs], # pylint: disable=protected-access
+ [t._as_tf_output() for t in outputs], # pylint: disable=protected-access
+ [compat.as_text("%s" % i) for i in range(len(outputs))],
+ None,
+ compat.as_text(""),
+ status)
+ # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature,
+ # but also in general it's nice not to depend on it.
+ with c_api_util.tf_buffer() as buffer_:
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
+ proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
+ fdef = function_pb2.FunctionDef()
+ fdef.ParseFromString(compat.as_bytes(proto_data))
+ return fdef, fn
@contextlib.contextmanager
@@ -115,6 +150,10 @@ class CapturingGraph(ops.Graph):
# for resource tensors.
self._last_op_using_resource_tensor = {}
+ # TODO(apassos) remove once the C API is used by default.
+ def _use_c_api_hack(self):
+ return True
+
def clear_resource_control_flow_state(self):
self._last_op_using_resource_tensor = {}
@@ -207,14 +246,20 @@ def _inference_name(n):
return "__inference_%s_%s" % (n, ops.uid())
+# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
+# so it doesn't have the definition-generating logic and is just a container for
+# an already-defined function.
class _DefinedFunction(object):
"""Mocks the interface of tf _DefinedFunction."""
- def __init__(self, fdef):
+ def __init__(self, fdef, fn):
self.definition = fdef
self.name = fdef.signature.name
+ self.signature = fdef.signature
self.grad_func_name = None
self.python_grad_func = None
+ self._c_func = fn
+ self._grad_func = None
def _map_sequence_obj_to_idx(sequence):
@@ -250,6 +295,7 @@ class GraphModeFunction(object):
input_placeholders,
extra_inputs,
fdef,
+ fn,
graph,
operations,
func_outputs,
@@ -263,7 +309,7 @@ class GraphModeFunction(object):
self._graph = graph
self._has_backprop = False
self._func_name = fdef.signature.name
- self._fdef = _DefinedFunction(fdef)
+ self._fdef = _DefinedFunction(fdef, fn)
self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
@@ -283,38 +329,45 @@ class GraphModeFunction(object):
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
- filtered_outputs = [
- x for x in self._returns if x is not None
- ]
+ filtered_outputs = [x for x in self._returns if x is not None]
self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
- ]
+ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
in_gradients = gradients_impl.gradients(
filtered_outputs,
self._input_placeholders,
grad_ys=self._out_grad_placeholders)
- shapes = [x.shape for x in in_gradients if x is not None]
+ shapes = tuple(x.shape for x in in_gradients if x is not None)
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
- forward_function_def = make_function_def(
- self._graph, self._ops, self._input_placeholders,
+ forward_name = _forward_name(self._func_name)
+ forward_function_def, forward_fn = make_function_def(
+ forward_name, self._graph, self._ops, self._input_placeholders,
filtered_outputs + captures)
- self._forward_fdef = _DefinedFunction(forward_function_def)
- _register_with_name(_forward_name(self._func_name), forward_function_def)
- backward_outputs = [x for x in in_gradients if x is not None]
+ self._forward_fdef = _DefinedFunction(forward_function_def, forward_fn)
+ _register(forward_fn)
+ backward_outputs = tuple(x for x in in_gradients if x is not None)
all_inputs = self._out_grad_placeholders + captures
- backward_function_def = make_function_def(
- self._graph, [x.op for x in self._out_grad_placeholders
- ] + list(sorted(c.known_ops, key=lambda x: x.name)),
+ # Excluding input ops from the body as we do not intend to execute these
+ # operations when the function is executed.
+ all_ignored_ops = frozenset(x.op for x in all_inputs)
+ # Enforce a deterministic order of operations in the generated graph. This
+ # means rerunning the function-defining code will always define the same
+ # function, which is useful if we serialize this etc.
+ fdef_ops = tuple(x for x in sorted(c.known_ops, key=lambda x: x.name)
+ if x not in all_ignored_ops)
+ bname = _backward_name(self._func_name)
+ backward_function_def, backward_fn = make_function_def(
+ bname, self._graph, fdef_ops,
all_inputs, backward_outputs)
- _register_with_name(_backward_name(self._func_name), backward_function_def)
+ _register(backward_fn)
self._backward_function = GraphModeFunction(
- all_inputs, [], backward_function_def, self._graph, c.known_ops,
- in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
+ all_inputs, [], backward_function_def, backward_fn, self._graph,
+ c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs),
+ shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
all_args = args + self._extra_inputs
- signature = self._forward_fdef.definition.signature
+ signature = self._forward_fdef.signature
ctx = context.context()
if ctx.in_graph_mode():
g = ops.get_default_graph()
@@ -325,7 +378,7 @@ class GraphModeFunction(object):
return ops.internal_convert_to_tensor(x, ctx=ctx)
op = g.create_op(
signature.name, [make_tensor(x) for x in all_args],
- [dtypes_module.DType(x.type) for x in signature.output_arg],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
op_def=signature,
name="FunctionCall",
compute_shapes=False)
@@ -361,11 +414,8 @@ class GraphModeFunction(object):
if v._trainable: # pylint: disable=protected-access
tape.watch_variable(v)
- tensor_inputs = [
- x for x in nest.flatten(args)
- if isinstance(x, ops.Tensor)
- ]
-
+ tensor_inputs = [x for x in nest.flatten(args)
+ if isinstance(x, ops.Tensor)]
if tape.should_record(tensor_inputs) or tape.should_record(
self._extra_inputs):
if not self._has_backprop:
@@ -384,7 +434,7 @@ class GraphModeFunction(object):
args = list(tensor_inputs) + self._extra_inputs
op = g.create_op(
signature.name, [ops.convert_to_tensor(x) for x in args],
- [dtypes_module.DType(x.type) for x in signature.output_arg],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
op_def=signature,
name="FunctionCall",
compute_shapes=False)
@@ -469,29 +519,32 @@ def _defun_internal(name, func, args, kwds):
extra_inputs = []
extra_placeholders = []
outputs_list = nest.flatten(func_outputs)
- output_shapes = [x.shape for x in outputs_list if x is not None]
+ output_shapes = tuple(x.shape for x in outputs_list if x is not None)
- flat_inputs = [
- x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
- ]
+ flat_inputs = [x for x in nest.flatten(func_inputs)
+ if isinstance(x, ops.Tensor)]
all_inputs = flat_inputs + list(extra_placeholders)
-
+ all_ignored_ops = frozenset(x.op for x in all_inputs)
func_def_outputs = [x for x in outputs_list if x is not None]
- inference_function_def = make_function_def(
- tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
+ fname = _inference_name(name)
+ operations = tuple(x for x in tmp_graph.get_operations()
+ if x not in all_ignored_ops)
+ inference_function_def, fn = make_function_def(
+ fname, tmp_graph, operations, all_inputs, func_def_outputs)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
- _register_with_name(f.name, f.definition)
- _register_with_name(_inference_name(name), inference_function_def)
+ _register(f._c_func) # pylint: disable=protected-access
+ _register(fn)
return GraphModeFunction(
all_inputs,
extra_inputs,
inference_function_def,
+ fn,
tmp_graph,
- tmp_graph.get_operations(),
+ operations,
func_outputs,
_map_sequence_obj_to_idx(func_def_outputs),
output_shapes,
@@ -517,10 +570,9 @@ def _cache_key(x):
return x
-def _register_with_name(name, fdef):
- """Registers the function `fdef` with the name `name`."""
- fdef.signature.name = name
- context.context().add_function_def(fdef)
+def _register(fn):
+ """Registers the function `fn`."""
+ context.context().add_function(fn)
# TODO(apassos): better error messages for non-hashable arguments.
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index faf0ac88bc..3da100d800 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -318,7 +318,9 @@ def _graph_callable_internal(func, shape_and_dtypes):
placeholder_inputs = flat_inputs+ list(extra_placeholders)
func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
- initializer_function_def = function.make_function_def(
+ initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
+ initializer_function_def, initializer_fn = function.make_function_def(
+ initialization_name,
tmp_graph,
initializing_operations,
placeholder_inputs,
@@ -327,13 +329,13 @@ def _graph_callable_internal(func, shape_and_dtypes):
# Also, what about the gradient registry of these functions? Those need to be
# addressed as well.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register_with_name(f.name, f.definition) # pylint: disable=protected-access
- function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
- initializer_function_def)
+ function._register(f._c_func) # pylint: disable=protected-access
+ function._register(initializer_fn) # pylint: disable=protected-access
initializer_function = function.GraphModeFunction(
placeholder_inputs,
extra_inputs,
initializer_function_def,
+ initializer_fn,
tmp_graph,
initializing_operations,
func_outputs,
@@ -342,18 +344,20 @@ def _graph_callable_internal(func, shape_and_dtypes):
capture_func_def_outputs = [
x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
- captured_function_def = function.make_function_def(
+ captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
+ captured_function_def, capturing_fn = function.make_function_def(
+ captured_function_name,
tmp_graph,
capturing_operations,
placeholder_inputs,
capture_func_def_outputs)
- function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
- captured_function_def)
+ function._register(capturing_fn) # pylint: disable=protected-access
captured_function = function.GraphModeFunction(
placeholder_inputs,
extra_inputs,
captured_function_def,
+ capturing_fn,
tmp_graph,
capturing_operations,
captured_outputs,
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
index 548e16a909..b9e6ca2a93 100644
--- a/tensorflow/python/eager/graph_callable_test.py
+++ b/tensorflow/python/eager/graph_callable_test.py
@@ -152,7 +152,6 @@ class GraphCallableTest(test.TestCase):
self.assertAllEqual(5, f(constant_op.constant(2)))
def testNestedFunction(self):
-
# TensorFlow function (which is what would be used in TensorFlow graph
# construction).
@function.Defun(dtypes.int32, dtypes.int32)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2217513966..36daf59647 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -599,11 +599,6 @@ class Tensor(_TensorLike):
"""
return _eval_using_default_session(self, feed_dict, self.graph, session)
- def _dup(self):
- ret = copy.copy(self)
- ret._id = uid() # pylint: disable=protected-access
- return ret
-
# TODO(agarwal): consider getting rid of this.
class _EagerTensorBase(Tensor):
@@ -729,9 +724,6 @@ class _EagerTensorBase(Tensor):
return new_tensor
# pylint: enable=protected-access
- def _dup(self):
- return self._copy(device_name=self.device)
-
@property
def shape(self):
return tensor_shape.TensorShape(self._shape_tuple())
@@ -1794,7 +1786,7 @@ class Operation(object):
c_api.SetRequestedDevice(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
- _device_string(device))
+ compat.as_text(_device_string(device)))
else:
self._node_def.device = _device_string(device)
@@ -2083,7 +2075,7 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if _USE_C_API:
+ if self._c_op:
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
@@ -2652,11 +2644,16 @@ class Graph(object):
# TODO(skyewm): fold as much of the above as possible into the C
# implementation
- if _USE_C_API:
+ if _USE_C_API or self._use_c_api_hack():
self._scoped_c_graph = c_api_util.ScopedTFGraph()
else:
self._scoped_c_graph = None
+ # TODO(apassos) remove once the C API is used by default.
+ def _use_c_api_hack(self):
+ """Temporary hack; can be overridden to force C API usage."""
+ return False
+
def _convert_stack(self, stack, include_func_start_lineno=False):
"""Converts a stack extracted using _extract_stack() to a traceback stack.
@@ -2985,9 +2982,14 @@ class Graph(object):
# Add function to graph
# pylint: disable=protected-access
if self._c_graph:
- assert function._c_func, (
- "Cannot add function created without C API support to graph "
- "created with C API support")
+ # Handle functions created without using the C API. TODO(apassos,skyewm)
+ # remove this when all functions are generated using the C API by default
+ # as this will be unnecessary.
+ if not function._c_func:
+ with errors.raise_exception_on_not_ok_status() as status:
+ serialized = function.definition.SerializeToString()
+ function._c_func = c_api.TF_FunctionImportFunctionDef(
+ serialized, status)
with errors.raise_exception_on_not_ok_status() as status:
gradient = function._grad_func._c_func if function._grad_func else None
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 82b154164e..82750e9e49 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -18,6 +18,7 @@ limitations under the License.
%rename("%s") TFE_NewContext;
%rename("%s") TFE_DeleteContext;
%rename("%s") TFE_ContextListDevices;
+%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
@@ -149,7 +150,7 @@ limitations under the License.
}
$1 = &temp;
$1->resize(PyInt_AsLong($input), nullptr);
-}
+}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {