aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/pywrap_tfe_src.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/pywrap_tfe_src.cc')
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc312
1 files changed, 196 insertions, 116 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 1ed814258b..9f2f4e06ad 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1403,9 +1403,13 @@ class PyVSpace
PyObject* arglist =
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
tensorflow::int64 r = MakeInt(result);
Py_DECREF(result);
- Py_DECREF(arglist);
return r;
}
@@ -1740,117 +1744,167 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
-bool OpDoesntRequireOutput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "MatMul",
- "Conv2DBackpropInput",
- "Conv2DBackpropFilter",
- "Conv3D",
- "Conv3DBackpropInputV2",
- "AvgPool3D",
- "AvgPool3DGrad",
- "MaxPool3D",
- "MaxPool3DGrad",
- "MaxPool3DGradGrad",
- "BiasAdd",
- "BiasAddV1",
- "BiasAddGrad",
- "Softplus",
- "SoftplusGrad",
- "Softsign",
- "ReluGrad",
- "Conv2D",
- "DepthwiseConv2dNative",
- "Dilation2D",
- "AvgPool",
- "AvgPoolGrad",
- "BatchNormWithGlobalNormalization",
- "L2Loss",
- "Sum",
- "Prod",
- "SegmentSum",
- "SegmentMean",
- "SparseSegmentSum",
- "SparseSegmentMean",
- "SparseSegmentSqrtN",
- "SegmentMin",
- "SegmentMax",
- "UnsortedSegmentSum",
- "UnsortedSegmentMax",
- "Abs",
- "Neg",
- "ReciprocalGrad",
- "Square",
- "Expm1",
- "Log",
- "Log1p",
- "TanhGrad",
- "SigmoidGrad",
- "Sign",
- "Sin",
- "Cos",
- "Tan",
- "Add",
- "Sub",
- "Mul",
- "Div",
- "RealDiv",
- "Maximum",
- "Minimum",
- "SquaredDifference",
- "Select",
- "SparseMatMul",
- "BatchMatMul",
- "Complex",
- "Real",
- "Imag",
- "Angle",
- "Conj",
- "Cast",
- "Cross",
- "Cumsum",
- "Cumprod",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
- "StridedSlice",
+// Returns a pair where the first value of the pair indicates whether or not all
+// outputs are unused. If the first value is false, the second value is a
+// set that identifies which of the output indices are unused.
+bool OpGradientDoesntRequireOutputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any outputs.
+ {"Identity", {true, {}}},
+ {"MatMul", {true, {}}},
+ {"Conv2DBackpropInput", {true, {}}},
+ {"Conv2DBackpropFilter", {true, {}}},
+ {"Conv3D", {true, {}}},
+ {"Conv3DBackpropInputV2", {true, {}}},
+ {"AvgPool3D", {true, {}}},
+ {"AvgPool3DGrad", {true, {}}},
+ {"MaxPool3D", {true, {}}},
+ {"MaxPool3DGrad", {true, {}}},
+ {"MaxPool3DGradGrad", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"BiasAddV1", {true, {}}},
+ {"BiasAddGrad", {true, {}}},
+ {"Softplus", {true, {}}},
+ {"SoftplusGrad", {true, {}}},
+ {"Softsign", {true, {}}},
+ {"ReluGrad", {true, {}}},
+ {"Conv2D", {true, {}}},
+ {"DepthwiseConv2dNative", {true, {}}},
+ {"Dilation2D", {true, {}}},
+ {"AvgPool", {true, {}}},
+ {"AvgPoolGrad", {true, {}}},
+ {"BatchNormWithGlobalNormalization", {true, {}}},
+ {"L2Loss", {true, {}}},
+ {"Sum", {true, {}}},
+ {"Prod", {true, {}}},
+ {"SegmentSum", {true, {}}},
+ {"SegmentMean", {true, {}}},
+ {"SparseSegmentSum", {true, {}}},
+ {"SparseSegmentMean", {true, {}}},
+ {"SparseSegmentSqrtN", {true, {}}},
+ {"SegmentMin", {true, {}}},
+ {"SegmentMax", {true, {}}},
+ {"UnsortedSegmentSum", {true, {}}},
+ {"UnsortedSegmentMax", {true, {}}},
+ {"Abs", {true, {}}},
+ {"Neg", {true, {}}},
+ {"ReciprocalGrad", {true, {}}},
+ {"Square", {true, {}}},
+ {"Expm1", {true, {}}},
+ {"Log", {true, {}}},
+ {"Log1p", {true, {}}},
+ {"TanhGrad", {true, {}}},
+ {"SigmoidGrad", {true, {}}},
+ {"Sign", {true, {}}},
+ {"Sin", {true, {}}},
+ {"Cos", {true, {}}},
+ {"Tan", {true, {}}},
+ {"Add", {true, {}}},
+ {"Sub", {true, {}}},
+ {"Mul", {true, {}}},
+ {"Div", {true, {}}},
+ {"RealDiv", {true, {}}},
+ {"Maximum", {true, {}}},
+ {"Minimum", {true, {}}},
+ {"SquaredDifference", {true, {}}},
+ {"Select", {true, {}}},
+ {"SparseMatMul", {true, {}}},
+ {"BatchMatMul", {true, {}}},
+ {"Complex", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Angle", {true, {}}},
+ {"Conj", {true, {}}},
+ {"Cast", {true, {}}},
+ {"Cross", {true, {}}},
+ {"Cumsum", {true, {}}},
+ {"Cumprod", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"StridedSlice", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops that don't require a subset of outputs.
+ {"FusedBatchNorm", {false, {0, 1, 2}}},
});
- return ops_that_dont_require_outputs->find(op_name) !=
- ops_that_dont_require_outputs->end();
-}
-
-bool OpDoesntRequireInput(const string& op_name) {
- static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
- new tensorflow::gtl::FlatSet<string>({
- "Identity",
- "Softmax",
- "LogSoftmax",
- "BiasAdd",
- "Relu",
- "Relu6",
- "Elu",
- "Selu",
- "SparseSoftmaxCrossEntropyWithLogits",
- "Neg",
- "Inv",
- "Reciprocal",
- "Sqrt",
- "Exp",
- "Tanh",
- "Sigmoid",
- "Real",
- "Imag",
- "Conj",
- "ReadVariableOp",
- "VarHandleOp",
- "Shape",
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+// Returns a pair where the first value of the pair indicates whether or not all
+// inputs are unused. If the first value is false, the second value is a
+// set that identifies which of the input indices are unused.
+bool OpGradientDoesntRequireInputIndices(
+ const string& op_name,
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+ static tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+ new tensorflow::gtl::FlatMap<
+ string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+ // Ops that don't require any inputs.
+ {"Identity", {true, {}}},
+ {"Softmax", {true, {}}},
+ {"LogSoftmax", {true, {}}},
+ {"BiasAdd", {true, {}}},
+ {"Relu", {true, {}}},
+ {"Relu6", {true, {}}},
+ {"Elu", {true, {}}},
+ {"Selu", {true, {}}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+ {"Neg", {true, {}}},
+ {"Inv", {true, {}}},
+ {"Reciprocal", {true, {}}},
+ {"Sqrt", {true, {}}},
+ {"Exp", {true, {}}},
+ {"Tanh", {true, {}}},
+ {"Sigmoid", {true, {}}},
+ {"Real", {true, {}}},
+ {"Imag", {true, {}}},
+ {"Conj", {true, {}}},
+ {"ReadVariableOp", {true, {}}},
+ {"VarHandleOp", {true, {}}},
+ {"Shape", {true, {}}},
+ {"Fill", {true, {}}},
+
+ // Ops that don't require a subset of inputs.
+ {"FusedBatchNorm", {false, {2}}},
});
- return ops_that_dont_require_inputs->find(op_name) !=
- ops_that_dont_require_inputs->end();
+ auto it = m->find(op_name);
+
+ if (it == m->end()) return false;
+
+ *output = &it->second;
+ return true;
+}
+
+PyObject* CopySequenceSettingIndicesToNull(
+ PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
+ tensorflow::Safe_PyObjectPtr fast_seq(
+ PySequence_Fast(seq, "unable to allocate"));
+ PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
+ for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
+ PyObject* item;
+ if (indices.find(i) != indices.end()) {
+ item = Py_None;
+ } else {
+ item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
+ }
+ Py_INCREF(item);
+ PyTuple_SET_ITEM(result, i, item);
+ }
+ return result;
}
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
@@ -1870,16 +1924,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
if (!should_record) Py_RETURN_NONE;
string c_op_name = TFE_GetPythonString(op_name);
+
PyObject* op_outputs;
- if (OpDoesntRequireOutput(c_op_name)) {
- op_outputs = Py_None;
+ bool op_outputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
+
+ if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
+ if (outputs_not_required->first) {
+ op_outputs = Py_None;
+ } else {
+ op_outputs_tuple_created = true;
+ op_outputs = CopySequenceSettingIndicesToNull(
+ results, outputs_not_required->second);
+ }
} else {
op_outputs = results;
}
PyObject* op_inputs;
- if (OpDoesntRequireInput(c_op_name)) {
- op_inputs = Py_None;
+ bool op_inputs_tuple_created = false;
+ std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
+
+ if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
+ if (inputs_not_required->first) {
+ op_inputs = Py_None;
+ } else {
+ op_inputs_tuple_created = true;
+ op_inputs =
+ CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+ }
} else {
op_inputs = inputs;
}
@@ -1922,6 +1995,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
});
Py_DECREF(num_inputs);
+ if (op_outputs_tuple_created) Py_DECREF(op_outputs);
+ if (op_inputs_tuple_created) Py_DECREF(op_inputs);
Py_RETURN_NONE;
}
@@ -2492,13 +2567,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
+ int delta = 1;
if (!output_arg.number_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.number_attr()];
+ delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.type_list_attr()];
- } else {
- num_retvals++;
+ delta = attr_list_sizes[output_arg.type_list_attr()];
+ }
+ if (delta < 0) {
+ RaiseFallbackException(
+ "Attributes suggest that the size of an output list is less than 0");
+ return nullptr;
}
+ num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);