author    Akshay Modi <nareshmodi@google.com>             2018-02-21 14:42:11 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-21 14:48:12 -0800
commit    32e39947c80ad410042a5aea266f197e9ecd289d (patch)
tree      767b0434f701274d98a689d012eb6bb45363e22e /tensorflow/python
parent    24edd50d23416b2c15a4e3509ce079e558500894 (diff)
Record gradient in C
PiperOrigin-RevId: 186522240
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/backprop.py        160
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h        14
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc  232
-rw-r--r--  tensorflow/python/pywrap_tfe.i               2

4 files changed, 235 insertions, 173 deletions
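
In short: this change moves eager gradient recording from Python into C. The new TFE_Py_RecordGradient entry point decides in C whether any active tape could record, and only then calls back into Python (via the registered backward-function getter) to build a gradient closure. A minimal sketch of the user-level path this serves, assuming the tf.contrib.eager API of this era (illustrative, not part of the patch):

    import tensorflow as tf

    tfe = tf.contrib.eager
    tfe.enable_eager_execution()

    x = tf.constant(3.0)
    with tfe.GradientTape() as tape:
      tape.watch(x)
      y = x * x  # each eager op records its gradient via TFE_Py_RecordGradient
    print(tape.gradient(y, x))  # tf.Tensor of 6.0
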
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index d8e13d7231..5505661dbb 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -137,110 +137,6 @@ _gradient_functions_lock = threading.Lock()
_tracing = False
-# TODO(apassos) replace this with a mechanism which can happen at the op
-# gradient function registration site, to be less error-prone
-# TODO(apassos) add ops other than those in nn_grad and math_grad
-_ops_which_dont_need_outputs = set([
- "Identity",
- "MatMul",
- "Conv2DBackpropInput",
- "Conv2DBackpropFilter",
- "Conv3D",
- "Conv3DBackpropInputV2",
- "AvgPool3D",
- "AvgPool3DGrad",
- "MaxPool3D",
- "MaxPool3DGrad",
- "MaxPool3DGradGrad",
- "BiasAdd",
- "BiasAddV1",
- "BiasAddGrad",
- "Relu6",
- "Softplus",
- "SoftplusGrad",
- "Softsign",
- "ReluGrad",
- "Conv2D",
- "DepthwiseConv2dNative",
- "Dilation2D",
- "AvgPool",
- "AvgPoolGrad",
- "BatchNormWithGlobalNormalization",
- "L2Loss",
- "Sum",
- "Prod",
- "SegmentSum",
- "SegmentMean",
- "SparseSegmentSum",
- "SparseSegmentMean",
- "SparseSegmentSqrtN",
- "SegmentMin",
- "SegmentMax",
- "UnsortedSegmentSum",
- "UnsortedSegmentMax",
- "Abs",
- "Neg",
- "ReciprocalGrad",
- "Square",
- "Expm1",
- "Log",
- "Log1p",
- "TanhGrad",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Cos",
- "Tan",
- "Add",
- "Sub",
- "Mul",
- "Div",
- "RealDiv",
- "Maximum",
- "Minimum",
- "SquaredDifference",
- "Select",
- "SparseMatMul",
- "BatchMatMul",
- "Complex",
- "Real",
- "Imag",
- "Angle",
- "Conj",
- "Cast",
- "Cross",
- "Cumsum",
- "Cumprod",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
-])
-
-_ops_which_dont_need_inputs = set([
- "Identity",
- "Softmax",
- "LogSoftmax",
- "BiasAdd",
- "Relu",
- "Elu",
- "Selu",
- "SparseSoftmaxCrossEntropyWithLogits",
- "Neg",
- "Inv",
- "Reciprocal",
- "Sqrt",
- "Exp",
- "Tanh",
- "Sigmoid",
- "Real",
- "Imag",
- "Conj",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
-])
-
-
# TODO(agarwal): use an automatic mechanism for handling None arguments to
# gradient functions.
# Some gradient functions can accept None arguments for gradients. The following
@@ -259,57 +155,25 @@ _grad_fn_accepts_none_for_indices = {
}
-def _record_gradient(op_name, inputs, attrs, results, name):
- """Records gradients for a TensorFlow operation.
-
- Args:
- op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
- execute.
- inputs: A flat list of Tensor object inputs to the operation.
- attrs: A tuple with alternating string attr names and attr values for this
- operation.
- results: The results of the operation (as a flat list).
- name: Customized name for the operation.
-
- Returns:
- A list of maybe-wrapped results. Either Tensors or TensorNodes.
-
- Raises:
- An exception on error.
- """
- if not tape.could_possibly_record():
- return
-
- if op_name in _ops_which_dont_need_outputs:
- op_outputs = None
- else:
- # TODO(apassos) this line creates a weak circular reference where the
- # backprop function keeps an output alive which in turn keeps the tape entry
- # alive which keeps the backprop function alive. Figure out how to break
- # this up without breaking second derivatives of ops like Exp whose
- # gradients depend only on the outputs.
- op_outputs = results
-
- if op_name in _ops_which_dont_need_inputs:
- op_inputs = None
- else:
- op_inputs = inputs
-
- num_inputs = len(inputs)
+def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs):
def grad_fn(*orig_outputs):
- """Generated gradient function."""
result = _magic_gradient_function(op_name, attrs, num_inputs,
op_inputs, op_outputs, orig_outputs)
if _tracing:
- print("Gradient for", (name if name else op_name), "inputs", op_inputs,
- "output_grads", orig_outputs, "gradients", result)
+ print("Gradient for", op_name, "inputs", op_inputs, "output_grads",
+ orig_outputs, "gradients", result)
return nest.flatten(result)
- tape.record_operation(op_name, results, inputs, grad_fn)
- if _tracing:
- print("Computed op", (name if name else op_name), "inputs", inputs,
- "outputs", results)
+ return grad_fn
+
+
+pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn)
+
+
+def _record_gradient(op_name, inputs, attrs, results, name):
+ return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
+ results, name)
execute.record_gradient = _record_gradient
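
After this diff, backprop.py only supplies the closure factory; the tape check and the input/output pruning happen in C. A toy stand-in for the getter contract (hypothetical, handles only Square; the real _get_backward_fn above defers to _magic_gradient_function):

    def toy_backward_fn_getter(op_name, attrs, num_inputs, op_inputs, op_outputs):
      # Same signature the C layer invokes the registered getter with.
      def grad_fn(*output_grads):
        # Square is in the doesn't-need-outputs set, so op_outputs is None
        # here and op_inputs carries the forward inputs: d(x^2)/dx = 2x.
        if op_name == "Square":
          return [output_grads[0] * 2.0 * op_inputs[0]]
        raise NotImplementedError("toy getter only handles Square")
      return grad_fn

    # Illustration only; registering this would replace the real machinery:
    # pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(toy_backward_fn_getter)
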
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 16b7d1a119..f9692a8910 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,15 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
+// Registers e as the backward_function_getter.
+// The registered function creates a backward function (a function that can
+// return the gradient of the inputs of an op given the gradient of its
+// outputs).
+// The registered function will be passed the following arguments:
+// op_name, attrs, num_inputs, op_inputs, op_outputs
+//
+// This function is not thread-safe.
+PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e);
+
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
@@ -165,6 +174,11 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
// directive.
PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args);
+// Record the gradient for a given op.
+PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
+ PyObject* attrs, PyObject* results,
+ PyObject* name);
+
// Returns the set of variables watched by the given tape.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index cabbcc48fd..30e08c8e65 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
@@ -575,6 +576,9 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
// Python subclass of Exception that is created to signal fallback.
PyObject* fallback_exception_class = nullptr;
+// Python function that returns a backward_function.
+PyObject* backward_function_getter = nullptr;
+
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
@@ -647,6 +651,23 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
}
}
+PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) {
+ if (backward_function_getter != nullptr) {
+ Py_DECREF(backward_function_getter);
+ }
+ if (!PyCallable_Check(e)) {
+ backward_function_getter = nullptr;
+ PyErr_SetString(PyExc_TypeError,
+ "TFE_Py_RegisterBackwardFunctionGetter: "
+ "Registered object should be function.");
+ return nullptr;
+ } else {
+ Py_INCREF(e);
+ backward_function_getter = e;
+ Py_RETURN_NONE;
+ }
+}
+
void RaiseFallbackException(const char* message) {
if (fallback_exception_class != nullptr) {
PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
@@ -1062,16 +1083,10 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return result;
}
-void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
- PyObject* input_tensors,
- PyObject* backward_function) {
- if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
- return;
- }
- std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
- if (PyErr_Occurred()) {
- return;
- }
+namespace {
+void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+ const std::vector<tensorflow::int64>& input_ids,
+ PyObject* backward_function) {
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
@@ -1110,6 +1125,19 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
[backward_function]() { Py_DECREF(backward_function); });
}
}
+} // namespace
+
+void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
+ PyObject* input_tensors,
+ PyObject* backward_function) {
+ if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
+ return;
+ }
+ std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
+ if (PyErr_Occurred()) return;
+
+ TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function);
+}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : SafeTapeSet()) {
@@ -1430,6 +1458,164 @@ bool RaiseIfNotPyList(PyObject* list, const string& attr_name) {
return true;
}
+bool OpDoesntRequireOutput(const string& op_name) {
+ static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
+ new tensorflow::gtl::FlatSet<string>({
+ "Identity",
+ "MatMul",
+ "Conv2DBackpropInput",
+ "Conv2DBackpropFilter",
+ "Conv3D",
+ "Conv3DBackpropInputV2",
+ "AvgPool3D",
+ "AvgPool3DGrad",
+ "MaxPool3D",
+ "MaxPool3DGrad",
+ "MaxPool3DGradGrad",
+ "BiasAdd",
+ "BiasAddV1",
+ "BiasAddGrad",
+ "Relu6",
+ "Softplus",
+ "SoftplusGrad",
+ "Softsign",
+ "ReluGrad",
+ "Conv2D",
+ "DepthwiseConv2dNative",
+ "Dilation2D",
+ "AvgPool",
+ "AvgPoolGrad",
+ "BatchNormWithGlobalNormalization",
+ "L2Loss",
+ "Sum",
+ "Prod",
+ "SegmentSum",
+ "SegmentMean",
+ "SparseSegmentSum",
+ "SparseSegmentMean",
+ "SparseSegmentSqrtN",
+ "SegmentMin",
+ "SegmentMax",
+ "UnsortedSegmentSum",
+ "UnsortedSegmentMax",
+ "Abs",
+ "Neg",
+ "ReciprocalGrad",
+ "Square",
+ "Expm1",
+ "Log",
+ "Log1p",
+ "TanhGrad",
+ "SigmoidGrad",
+ "Sign",
+ "Sin",
+ "Cos",
+ "Tan",
+ "Add",
+ "Sub",
+ "Mul",
+ "Div",
+ "RealDiv",
+ "Maximum",
+ "Minimum",
+ "SquaredDifference",
+ "Select",
+ "SparseMatMul",
+ "BatchMatMul",
+ "Complex",
+ "Real",
+ "Imag",
+ "Angle",
+ "Conj",
+ "Cast",
+ "Cross",
+ "Cumsum",
+ "Cumprod",
+ "ReadVariableOp",
+ "VarHandleOp",
+ "Shape",
+ });
+
+ return ops_that_dont_require_outputs->find(op_name) !=
+ ops_that_dont_require_outputs->end();
+}
+
+bool OpDoesntRequireInput(const string& op_name) {
+ static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
+ new tensorflow::gtl::FlatSet<string>({
+ "Identity",
+ "Softmax",
+ "LogSoftmax",
+ "BiasAdd",
+ "Relu",
+ "Elu",
+ "Selu",
+ "SparseSoftmaxCrossEntropyWithLogits",
+ "Neg",
+ "Inv",
+ "Reciprocal",
+ "Sqrt",
+ "Exp",
+ "Tanh",
+ "Sigmoid",
+ "Real",
+ "Imag",
+ "Conj",
+ "ReadVariableOp",
+ "VarHandleOp",
+ "Shape",
+ });
+
+ return ops_that_dont_require_inputs->find(op_name) !=
+ ops_that_dont_require_inputs->end();
+}
+
+PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
+ PyObject* results, PyObject* name) {
+ std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
+ if (PyErr_Occurred()) return nullptr;
+
+ bool should_record = false;
+ for (TFE_Py_Tape* tape : SafeTapeSet()) {
+ if (tape->tape->ShouldRecord(input_ids)) {
+ should_record = true;
+ break;
+ }
+ }
+
+ if (!should_record) Py_RETURN_NONE;
+
+ string c_op_name = TFE_GetPythonString(op_name);
+ PyObject* op_outputs;
+ if (OpDoesntRequireOutput(c_op_name)) {
+ op_outputs = Py_None;
+ } else {
+ op_outputs = results;
+ }
+
+ PyObject* op_inputs;
+ if (OpDoesntRequireInput(c_op_name)) {
+ op_inputs = Py_None;
+ } else {
+ op_inputs = inputs;
+ }
+
+ PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
+ PyObject* callback_args =
+ Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs);
+
+ PyObject* backward_function =
+ PyObject_CallObject(backward_function_getter, callback_args);
+ Py_DECREF(callback_args);
+ if (backward_function == nullptr) return nullptr;
+
+ TapeSetRecordOperation(op_name, results, input_ids, backward_function);
+
+ Py_DECREF(backward_function);
+
+ Py_RETURN_NONE;
+}
+
bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
const tensorflow::OpDef* op_def, PyObject* args,
const std::vector<PyObject*>& flattened_inputs,
@@ -1471,21 +1657,7 @@ bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
});
if (run_gradient_callback) {
- if (!PyCallable_Check(record_gradient_callback)) {
- PyErr_SetString(PyExc_TypeError,
- Printf("expected a function for "
- "record_gradient_callback, got %s instead",
- record_gradient_callback->ob_type->tp_name)
- .c_str());
- return false;
- }
-
- PyObject* callback_result =
- PyObject_CallObject(record_gradient_callback, callback_args);
- if (!callback_result) {
- return false;
- }
- Py_DECREF(callback_result);
+ RecordGradient(op_name, inputs, attrs, flattened_result, name);
}
if (run_post_exec_callbacks) {
@@ -1796,3 +1968,13 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
Py_DECREF(flat_result);
return result;
}
+
+PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
+ PyObject* attrs, PyObject* results,
+ PyObject* name) {
+ if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
+ Py_RETURN_NONE;
+ }
+
+ return RecordGradient(op_name, inputs, attrs, results, name);
+}
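
Why RecordGradient consults the two static FlatSets above: pinning fewer tensors on the tape lets the runtime free memory sooner. A self-contained Python sketch of the pruning rule (sets abridged from the diff; function and constant names hypothetical):

    # Abridged stand-ins for the C-side OpDoesntRequireOutput/Input sets.
    OPS_NOT_NEEDING_OUTPUTS = frozenset({"MatMul", "Square", "Add", "Log"})
    OPS_NOT_NEEDING_INPUTS = frozenset({"Exp", "Sqrt", "Identity", "Softmax"})

    def tensors_to_pin(op_name, inputs, results):
      # Mirrors how RecordGradient prunes what the backward function captures.
      op_outputs = None if op_name in OPS_NOT_NEEDING_OUTPUTS else results
      op_inputs = None if op_name in OPS_NOT_NEEDING_INPUTS else inputs
      return op_inputs, op_outputs

    # d(exp(x))/dx == exp(x): the gradient needs only the op's output, so the
    # forward inputs are dropped and need not stay alive for backprop.
    print(tensors_to_pin("Exp", ["x"], ["exp(x)"]))  # -> (None, ['exp(x)'])
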
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 50f481d29e..7ab0db5268 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,9 +29,11 @@ limitations under the License.
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
+%rename("%s") TFE_Py_RegisterBackwardFunctionGetter;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_FastPathExecute;
+%rename("%s") TFE_Py_RecordGradient;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetRemove;