-rw-r--r--  tensorflow/c/c_api.cc                        |   4
-rw-r--r--  tensorflow/python/eager/benchmarks_test.py   |  29
-rw-r--r--  tensorflow/python/eager/core.py              |  14
-rw-r--r--  tensorflow/python/eager/pywrap_tfe.h         |  20
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc    | 630
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_test.py   | 114
-rw-r--r--  tensorflow/python/pywrap_tfe.i               |   1
7 files changed, 690 insertions(+), 122 deletions(-)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 3c7f041b39..b10af0f060 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -109,6 +109,10 @@ TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
+ if (code == TF_OK) {
+ s->status = Status::OK();
+ return;
+ }
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}
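
The guard above makes TF_SetStatus reset to an OK status and return before any
message handling, so the success path never constructs an error Status from
msg. A minimal Python sketch of the same control flow (illustrative only, not
TensorFlow API):

    TF_OK = 0

    class Status(object):
        """Toy stand-in for TF_Status."""

        def __init__(self):
            self.code, self.message = TF_OK, ""

        def set(self, code, msg):
            # Mirror of the C++ change: an OK code resets the status and
            # returns before `msg` is ever read.
            if code == TF_OK:
                self.code, self.message = TF_OK, ""
                return
            self.code, self.message = code, msg
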
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 75526ba9c1..7a60bd68e4 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -28,11 +28,14 @@ from __future__ import print_function
import time
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop # pylint: disable=unused-import
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
+from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
@@ -46,19 +49,25 @@ CPU = "/device:CPU:0"
GPU = "/device:GPU:0"
-def record_gradient_callback(inputs, attrs, results):
- return backprop._record_gradient("MatMul", inputs, attrs, results, None)
-
-
-def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False):
+def c_tfe_py_fastpath_execute(a,
+ b,
+ transpose_a=False,
+ transpose_b=False,
+ name=None):
ctx = context.context()
assert not ctx.in_graph_mode(
), "The prototype doesn't contain C code for graph construction"
- ctx_handle = ctx._handle # pylint: disable=protected-access
-
- return pywrap_tensorflow.TFE_Py_FastPathExecute(
- ctx_handle, None, "MatMul", record_gradient_callback, a, b,
- "transpose_a", transpose_a, "transpose_b", transpose_b)[0]
+ try:
+ return pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, name,
+ ctx._post_execution_callbacks, a, b, "transpose_a", transpose_a,
+ "transpose_b", transpose_b)
+ except core._NotOkStatusException as e:
+ if name is not None:
+ message = e.message + " name: " + name
+ else:
+ message = e.message
+ six.raise_from(core._status_to_exception(e.code, message), None)
class MicroBenchmarks(test.Benchmark):
diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py
index 483b717210..8fb6930020 100644
--- a/tensorflow/python/eager/core.py
+++ b/tensorflow/python/eager/core.py
@@ -47,3 +47,17 @@ class _NotOkStatusException(Exception):
pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException)
+
+
+class _FallbackException(Exception):
+ """Exception class to handle fallback from the fastpath.
+
+ The fastpath that we refer to here is the one implemented to reduce per-op
+ overheads (TFE_Py_FastPathExecute_C). If the conditions for executing the op
+ on the fastpath are not met, we fall back to a safer (and more complete)
+ slowpath, and this Exception is raised to signal that transition.
+ """
+ pass
+
+
+pywrap_tensorflow.TFE_Py_RegisterFallbackExceptionClass(_FallbackException)
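
With _FallbackException registered, a caller of the fast path is expected to
catch it and re-dispatch to the slow path. A minimal sketch of how a generated
op wrapper might use this contract (the _slow_path_matmul helper is
hypothetical, standing in for the existing non-fastpath execution code):

    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.eager import context
    from tensorflow.python.eager import core
    from tensorflow.python.eager import execute

    def _matmul_fastpath(a, b, transpose_a=False, transpose_b=False, name=None):
      ctx = context.context()
      try:
        return pywrap_tensorflow.TFE_Py_FastPathExecute(
            ctx._handle, ctx.device_name, "MatMul", execute.record_gradient,
            name, ctx._post_execution_callbacks, a, b,
            "transpose_a", transpose_a, "transpose_b", transpose_b)
      except core._FallbackException:
        # Fast-path conditions were not met (e.g. an input is not already an
        # EagerTensor); dispatch to the complete slow path instead.
        return _slow_path_matmul(a, b, transpose_a, transpose_b, name)
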
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 4aea134fa9..925eeb553c 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -45,9 +45,17 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status);
// Registers e as the Exception class for handling not ok Status. Returns
// Py_None if registration succeeds, else throws a TypeError and returns NULL.
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
+
+// Registers e as the Exception to be raised when the conditions of
+// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
+// is a signal to the calling code that it should fall back to the safer (and
+// more complete) code path.
+//
+// This function is not thread-safe.
+PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
@@ -142,10 +150,12 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
// or NULL for automatic selection.
// Item 3: op_name: Name of the TensorFlow op to execute.
// Item 4: record_gradient_callback: Callback that records the gradient of the
-// result.
-// The callback takes (inputs, attrs, result) - all sequences and
-// records the gradient.
-// Item 5 onwards: inputs - This is a list of inputs followed by a list of
+// result. The callback takes (op_name, inputs, attrs, results, name),
+// where inputs, attrs and results are sequences, and records the gradient.
+// Item 5: name: An optional name for the operation.
+// Item 6: List representing all callbacks to execute after a successful
+// op execution.
+// Item 7 onwards: inputs - This is a list of inputs followed by a list of
// attrs. It is not necessary for type attrs to be present.
//
// This is named _C since there doesn't seem to be any way to make it visible
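
To make the item numbering above concrete, the MatMul call from the updated
benchmark in this commit lines up against it as follows (this restates the
call from benchmarks_test.py rather than introducing new API):

    # item 1: ctx._handle                    - context handle
    # item 2: ctx.device_name                - device name (or None)
    # item 3: "MatMul"                       - op name
    # item 4: execute.record_gradient        - gradient-recording callback
    # item 5: None                           - optional op name
    # item 6: ctx._post_execution_callbacks  - post-execution callback list
    # items 7+: inputs (a, b), then attr name/value pairs
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.eager import context
    from tensorflow.python.eager import execute
    from tensorflow.python.ops import random_ops

    ctx = context.context()
    a = random_ops.random_uniform((2, 2))
    b = random_ops.random_uniform((2, 2))
    product = pywrap_tensorflow.TFE_Py_FastPathExecute(
        ctx._handle, ctx.device_name, "MatMul", execute.record_gradient,
        None, ctx._post_execution_callbacks, a, b,
        "transpose_a", False, "transpose_b", False)
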
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index d927f3abed..6e96cf6ebe 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
@@ -56,6 +58,7 @@ PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyInt_Check, PyInt_AsLong)
+PARSE_VALUE(ParseInt64LongValue, int64_t, PyLong_Check, PyLong_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
@@ -86,8 +89,40 @@ bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
return true;
}
-bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
- TF_AttrType type, TF_Status* status) {
+bool IsInteger(PyObject* py_value) {
+#if PY_MAJOR_VERSION >= 3
+ return PyLong_Check(py_value);
+#else
+ return PyInt_Check(py_value);
+#endif
+}
+
+// The passed in py_value is expected to be an object of the python type
+// dtypes.DType or an int.
+bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
+ int* value) {
+ if (IsInteger(py_value)) {
+ return ParseIntValue(key, py_value, status, value);
+ }
+
+ PyObject* py_type_enum = PyObject_GetAttrString(py_value, "_type_enum");
+ if (py_type_enum == nullptr) {
+ return false;
+ }
+
+ if (!ParseIntValue(key, py_type_enum, status, value)) {
+ Py_DECREF(py_type_enum);
+ return false;
+ }
+
+ Py_DECREF(py_type_enum);
+ return true;
+}
+
+bool SetOpAttrList(
+ TFE_Op* op, const char* key, PyObject* py_list, TF_AttrType type,
+ tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
+ TF_Status* status) {
if (!PySequence_Check(py_list)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
@@ -97,6 +132,7 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
return false;
}
const int num_values = PySequence_Size(py_list);
+ if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
#define PARSE_LIST(c_type, parse_fn) \
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
@@ -118,7 +154,7 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
PARSE_LIST(unsigned char, ParseBoolValue);
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_TYPE) {
- PARSE_LIST(int, ParseIntValue);
+ PARSE_LIST(int, ParseTypeValue);
TFE_OpSetAttrTypeList(op, key,
reinterpret_cast<const TF_DataType*>(values.get()),
num_values);
@@ -183,8 +219,65 @@ bool SetOpAttrList(TFE_Op* op, const char* key, PyObject* py_list,
return true;
}
-bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
- PyObject* py_value, TF_AttrType type, TF_Status* status) {
+void SetOpAttrListDefault(
+ TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, const char* key,
+ TF_AttrType type,
+ tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
+ TF_Status* status) {
+ if (type == TF_ATTR_STRING) {
+ int num_values = attr.default_value().list().s_size();
+ std::unique_ptr<const char* []> values(new const char*[num_values]);
+ (*attr_list_sizes)[key] = num_values;
+ for (int i = 0; i < num_values; i++) {
+ values[i] = attr.default_value().list().s(i).data();
+ }
+ TFE_OpSetAttrStringList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_INT) {
+ int num_values = attr.default_value().list().i_size();
+ std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
+ (*attr_list_sizes)[key] = num_values;
+ for (int i = 0; i < num_values; i++) {
+ values[i] = attr.default_value().list().i(i);
+ }
+ TFE_OpSetAttrIntList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_FLOAT) {
+ int num_values = attr.default_value().list().f_size();
+ std::unique_ptr<float[]> values(new float[num_values]);
+ (*attr_list_sizes)[key] = num_values;
+ for (int i = 0; i < num_values; i++) {
+ values[i] = attr.default_value().list().f(i);
+ }
+ TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_BOOL) {
+ int num_values = attr.default_value().list().b_size();
+ std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
+ (*attr_list_sizes)[key] = num_values;
+ for (int i = 0; i < num_values; i++) {
+ values[i] = attr.default_value().list().b(i);
+ }
+ TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
+ } else if (type == TF_ATTR_TYPE) {
+ int num_values = attr.default_value().list().type_size();
+ std::unique_ptr<int[]> values(new int[num_values]);
+ (*attr_list_sizes)[key] = num_values;
+ for (int i = 0; i < num_values; i++) {
+ values[i] = attr.default_value().list().type(i);
+ }
+ TFE_OpSetAttrTypeList(op, key,
+ reinterpret_cast<const TF_DataType*>(values.get()),
+ num_values);
+ } else {
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ "Shapes, tensors and lists are not yet implemented for "
+ "default valued attributes for an operation.");
+ }
+}
+
+bool SetOpAttrScalar(
+ TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
+ TF_AttrType type,
+ tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
+ TF_Status* status) {
if (type == TF_ATTR_STRING) {
const char* value;
if (!ParseStringValue(key, py_value, status, &value)) return false;
@@ -193,6 +286,10 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
int64_t value;
if (!ParseInt64Value(key, py_value, status, &value)) return false;
TFE_OpSetAttrInt(op, key, value);
+ // attr_list_sizes is set for all int attributes (since at this point we
+ // don't know whether the attribute will be used to calculate the size of
+ // an output list).
+ if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
} else if (type == TF_ATTR_FLOAT) {
float value;
if (!ParseFloatValue(key, py_value, status, &value)) return false;
@@ -203,7 +300,7 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
TFE_OpSetAttrBool(op, key, value);
} else if (type == TF_ATTR_TYPE) {
int value;
- if (!ParseIntValue(key, py_value, status, &value)) return false;
+ if (!ParseTypeValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
} else if (type == TF_ATTR_SHAPE) {
if (py_value == Py_None) {
@@ -269,6 +366,40 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
return true;
}
+void SetOpAttrScalarDefault(
+ TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, const char* attr_name,
+ TF_AttrType type,
+ tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
+ TF_Status* status) {
+ if (type == TF_ATTR_STRING) {
+ if (attr.default_value().value_case() == tensorflow::AttrValue::kS) {
+ TFE_OpSetAttrString(op, attr_name, attr.default_value().s().data());
+ }
+ } else if (type == TF_ATTR_INT) {
+ if (attr.default_value().value_case() == tensorflow::AttrValue::kI) {
+ TFE_OpSetAttrInt(op, attr_name,
+ static_cast<int64_t>(attr.default_value().i()));
+ (*attr_list_sizes)[attr.name()] = attr.default_value().i();
+ }
+ } else if (type == TF_ATTR_FLOAT) {
+ if (attr.default_value().value_case() == tensorflow::AttrValue::kF) {
+ TFE_OpSetAttrFloat(op, attr_name, attr.default_value().f());
+ }
+ } else if (type == TF_ATTR_BOOL) {
+ if (attr.default_value().value_case() == tensorflow::AttrValue::kB) {
+ TFE_OpSetAttrBool(op, attr_name, attr.default_value().b());
+ }
+ } else if (type == TF_ATTR_TYPE) {
+ if (attr.default_value().value_case() == tensorflow::AttrValue::kType) {
+ TFE_OpSetAttrType(op, attr_name,
+ static_cast<TF_DataType>(attr.default_value().type()));
+ }
+ } else {
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ "Shapes, tensors and lists are not yet implemented.");
+ }
+}
+
// start_index is the index at which the Tuple/List attrs will start getting
// processed.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
@@ -294,9 +425,38 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
if (TF_GetCode(out_status) != TF_OK) return;
if (is_list != 0) {
- if (!SetOpAttrList(op, key, py_value, type, out_status)) return;
+ if (!SetOpAttrList(op, key, py_value, type, nullptr, out_status)) return;
+ } else {
+ if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
+ return;
+ }
+ }
+}
+
+// This function sets the required op attrs. If an attr has the value None,
+// it reads the AttrDef to get the default value and sets that instead. Any
+// failure in this function simply triggers a fallback to the slow path.
+void SetOpAttrWithDefaults(
+ TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
+ const char* attr_name, PyObject* attr_value,
+ tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
+ TF_Status* status) {
+ unsigned char is_list = 0;
+ const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
+ if (TF_GetCode(status) != TF_OK) return;
+ if (attr_value == Py_None) {
+ if (is_list != 0) {
+ SetOpAttrListDefault(op, attr, attr_name, type, attr_list_sizes, status);
} else {
- if (!SetOpAttrScalar(ctx, op, key, py_value, type, out_status)) return;
+ SetOpAttrScalarDefault(op, attr, attr_name, type, attr_list_sizes,
+ status);
+ }
+ } else {
+ if (is_list != 0) {
+ SetOpAttrList(op, attr_name, attr_value, type, attr_list_sizes, status);
+ } else {
+ SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
+ status);
}
}
}
@@ -305,6 +465,9 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
+// Python subclass of Exception that is created to signal fallback.
+PyObject* fallback_exception_class = nullptr;
+
tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
@@ -360,6 +523,37 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
}
}
+PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
+ if (fallback_exception_class != nullptr) {
+ Py_DECREF(fallback_exception_class);
+ }
+ if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
+ fallback_exception_class = nullptr;
+ PyErr_SetString(PyExc_TypeError,
+ "TFE_Py_RegisterFallbackExceptionClass: "
+ "Registered class should be subclass of Exception.");
+ return nullptr;
+ } else {
+ Py_INCREF(e);
+ fallback_exception_class = e;
+ Py_RETURN_NONE;
+ }
+}
+
+void RaiseFallbackException(const char* message) {
+ if (fallback_exception_class != nullptr) {
+ PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
+ return;
+ }
+
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "Fallback exception type not set, attempting to fallback due to ",
+ message)
+ .data());
+}
+
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
if (TF_GetCode(status) == TF_OK) return 0;
const char* msg = TF_Message(status);
@@ -1036,17 +1230,62 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
}
namespace {
-static const int kFastPathExecuteInputStartIndex = 4;
+static const int kFastPathExecuteInputStartIndex = 6;
+
+PyObject* GetPythonObjectFromString(const char* s) {
+#if PY_MAJOR_VERSION >= 3
+ return PyUnicode_FromString(s);
+#else
+ return PyBytes_FromString(s);
+#endif
+}
-bool CheckEagerTensors(PyObject* seq, int start_index, int num_to_check) {
- for (int i = start_index; i < start_index + num_to_check; i++) {
- PyObject* item = PyTuple_GET_ITEM(seq, i);
- if (!EagerTensor_CheckExact(item)) return false;
+bool CheckEagerTensors(PyObject* seq, int start_index,
+ const tensorflow::OpDef& op_def) {
+ for (int i = 0; i < op_def.input_arg_size(); i++) {
+ PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
+ if (!op_def.input_arg(i).number_attr().empty() ||
+ !op_def.input_arg(i).type_list_attr().empty()) {
+ // This item should be a list input.
+ if (!PyList_Check(item)) return false;
+ for (Py_ssize_t j = 0; j < PyList_Size(item); j++) {
+ if (!EagerTensor_CheckExact(PyList_GET_ITEM(item, j))) return false;
+ }
+ } else if (!EagerTensor_CheckExact(item)) {
+ return false;
+ }
}
return true;
}
+// Adds input and type attr to the op, and to the list of flattened
+// inputs/attrs.
+bool AddInputToOp(PyObject* input, const tensorflow::OpDef::ArgDef* input_arg,
+ std::vector<PyObject*>* flattened_attrs,
+ std::vector<PyObject*>* flattened_inputs, TFE_Op* op,
+ TF_Status* status) {
+ TFE_TensorHandle* input_handle = EagerTensor_Handle(input);
+ if (input_arg != nullptr && !input_arg->type_attr().empty()) {
+ auto dtype = TFE_TensorHandleDataType(input_handle);
+ TFE_OpSetAttrType(op, input_arg->type_attr().data(), dtype);
+ if (flattened_attrs != nullptr) {
+ flattened_attrs->push_back(
+ GetPythonObjectFromString(input_arg->type_attr().data()));
+ flattened_attrs->push_back(PyLong_FromLong(dtype));
+ }
+ }
+
+ if (flattened_inputs != nullptr) {
+ flattened_inputs->push_back(input);
+ }
+ TFE_OpAddInput(op, input_handle, status);
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return false;
+ }
+ return true;
+}
+
const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
const char* op_name = TFE_GetPythonString(py_op_name);
if (op_name == nullptr) {
@@ -1073,59 +1312,109 @@ const char* GetDeviceName(PyObject* py_device_name) {
return nullptr;
}
-bool MaybeRunRecordGradientCallback(const tensorflow::OpDef* op_def,
- PyObject* args, PyObject* result,
- PyObject* record_gradient_callback) {
- if (*ThreadTapeIsStopped() || GetTapeSet()->empty() ||
- record_gradient_callback == Py_None) {
- return true;
- }
- if (!PyCallable_Check(record_gradient_callback)) {
- PyErr_SetString(
- PyExc_TypeError,
- Printf(
- "expected a function for record_gradient_callback, got %s instead",
- record_gradient_callback->ob_type->tp_name)
- .c_str());
+bool RaiseIfNotPyList(PyObject* list, const string& attr_name) {
+ if (!PyList_Check(list)) {
+ PyErr_SetString(PyExc_TypeError,
+ Printf("expected a list for attr %s, got %s instead",
+ attr_name.data(), list->ob_type->tp_name)
+ .data());
+
return false;
}
+ return true;
+}
- PyObject* inputs = PyTuple_New(op_def->input_arg_size());
- for (int i = 0; i < op_def->input_arg_size(); i++) {
- auto* input = PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
+bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks,
+ const tensorflow::OpDef* op_def, PyObject* args,
+ const std::vector<PyObject*>& flattened_inputs,
+ const std::vector<PyObject*>& flattened_attrs,
+ PyObject* flattened_result, PyObject* op_name, PyObject* name,
+ PyObject* record_gradient_callback, PyObject* callbacks) {
+ PyObject* inputs = PyTuple_New(flattened_inputs.size());
+ for (int i = 0; i < flattened_inputs.size(); i++) {
+ PyObject* input = flattened_inputs[i];
Py_INCREF(input);
PyTuple_SET_ITEM(inputs, i, input);
}
- int args_size = PyTuple_GET_SIZE(args);
- int num_attrs =
- args_size - op_def->input_arg_size() - kFastPathExecuteInputStartIndex;
+ int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
+ op_def->input_arg_size() -
+ kFastPathExecuteInputStartIndex;
+ int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
PyObject* attrs = PyTuple_New(num_attrs);
- for (int i = 0; i < num_attrs; i++) {
+
+ for (int i = 0; i < num_non_inferred_attrs; i++) {
auto* attr = PyTuple_GET_ITEM(
args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i);
Py_INCREF(attr);
PyTuple_SET_ITEM(attrs, i, attr);
}
+ for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
+ // Not INCREFing anything in flattened_attrs as each of those is a new
+ // reference, so allow the attrs tuple to steal the reference.
+ PyTuple_SET_ITEM(attrs, i, flattened_attrs.at(i - num_non_inferred_attrs));
+ }
- PyObject* callback_args = Py_BuildValue("OOO", inputs, attrs, result);
- PyObject_CallObject(record_gradient_callback, callback_args);
+ auto cleaner = tensorflow::gtl::MakeCleanup([inputs, attrs] {
+ Py_DECREF(inputs);
+ Py_DECREF(attrs);
+ });
+
+ if (run_gradient_callback) {
+ if (!PyCallable_Check(record_gradient_callback)) {
+ PyErr_SetString(PyExc_TypeError,
+ Printf("expected a function for "
+ "record_gradient_callback, got %s instead",
+ record_gradient_callback->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+
+ PyObject* callback_args =
+ Py_BuildValue("OOOOO", op_name, inputs, attrs, flattened_result, name);
+ PyObject* callback_result =
+ PyObject_CallObject(record_gradient_callback, callback_args);
+ Py_DECREF(callback_args);
+ if (!callback_result) {
+ return false;
+ }
+ Py_DECREF(callback_result);
+ }
+
+ if (run_post_exec_callbacks) {
+ // TODO(nareshmodi): update the execution callback interface to match the
+ // record_gradient interface so that this doesn't need to be rebuilt.
+ PyObject* callback_args =
+ Py_BuildValue("OOOOO", op_name, name, attrs, inputs, flattened_result);
+
+ for (Py_ssize_t i = 0; i < PyList_Size(callbacks); i++) {
+ PyObject* callback_fn = PyList_GET_ITEM(callbacks, i);
+ if (!PyCallable_Check(callback_fn)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ Printf("expected a function for "
+ "post execution callback in index %ld, got %s instead",
+ i, callback_fn->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+ PyObject* callback_result =
+ PyObject_CallObject(callback_fn, callback_args);
+ if (!callback_result) {
+ Py_DECREF(callback_args);
+ return false;
+ }
+ Py_DECREF(callback_result);
+ }
+ Py_DECREF(callback_args);
+ }
- Py_DECREF(inputs);
- Py_DECREF(callback_args);
- Py_DECREF(attrs);
return true;
}
+
} // namespace
PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
- TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
- PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
- const tensorflow::OpDef* op_def = GetOpDef(PyTuple_GET_ITEM(args, 2));
- if (op_def == nullptr) return nullptr;
- const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
- PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3);
-
Py_ssize_t args_size = PyTuple_GET_SIZE(args);
if (args_size < kFastPathExecuteInputStartIndex) {
PyErr_SetString(
@@ -1136,6 +1425,16 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
return nullptr;
}
+ TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
+ PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+ const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
+ PyObject* op_name = PyTuple_GET_ITEM(args, 2);
+ const tensorflow::OpDef* op_def = GetOpDef(op_name);
+ if (op_def == nullptr) return nullptr;
+ PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3);
+ PyObject* name = PyTuple_GET_ITEM(args, 4);
+ PyObject* callbacks = PyTuple_GET_ITEM(args, 5);
+
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
PyErr_SetString(
PyExc_ValueError,
@@ -1147,13 +1446,10 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
return nullptr;
}
- if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex,
- op_def->input_arg_size())) {
- // TODO(nareshmodi): Maybe some other way of signalling that this should
- // fall back?
- PyErr_SetString(PyExc_NotImplementedError,
- "This function does not handle the case of the path where "
- "all inputs are not already EagerTensors.");
+ if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex, *op_def)) {
+ RaiseFallbackException(
+ "This function does not handle the case where not all inputs are "
+ "already EagerTensors.");
return nullptr;
}
@@ -1167,62 +1463,236 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
return nullptr;
}
- TFE_OpSetDevice(op, device_name, status);
- if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
- return nullptr;
+ // Mapping of attr name to size - used to calculate the number of values
+ // to be expected by the TFE_Execute run.
+ tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
+
+ // Set non-inferred attrs, including setting defaults if the attr is passed in
+ // as None.
+ for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
+ i < args_size; i += 2) {
+ PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
+ const tensorflow::StringPiece attr_name(TFE_GetPythonString(py_attr_name));
+ PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
+
+ // Not creating an index since most of the time there are not more than a
+ // few attrs.
+ // TODO(nareshmodi): Maybe include the index as part of the
+ // OpRegistrationData.
+ for (const auto& attr : op_def->attr()) {
+ if (attr_name == attr.name()) {
+ SetOpAttrWithDefaults(ctx, op, attr, attr_name.data(), py_attr_value,
+ &attr_list_sizes, status);
+
+ if (TF_GetCode(status) != TF_OK) {
+ RaiseFallbackException(TF_Message(status));
+ return nullptr;
+ }
+
+ break;
+ }
+ }
}
- // Add non-type attrs.
- SetOpAttrs(ctx, op, args,
- kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
- status);
+ TFE_OpSetDevice(op, device_name, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return nullptr;
}
- // Add type attrs and inputs.
+ // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
+ // (similar to benchmark_tf_gradient_function_*). Also consider using an
+ // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
+ // point out problems with heap allocs.
+ bool run_gradient_callback = !*ThreadTapeIsStopped() &&
+ !GetTapeSet()->empty() &&
+ record_gradient_callback != Py_None;
+ bool run_post_exec_callbacks =
+ callbacks != Py_None && PyList_Size(callbacks) > 0;
+ bool run_callbacks = run_gradient_callback || run_post_exec_callbacks;
+ // Flat attrs and inputs as required by the record_gradient call. The attrs
+ // here only contain inferred attrs (non-inferred attrs are added directly
+ // from the input args).
+ // All items in flattened_attrs contain new references.
+ // All items in flattened_inputs contain borrowed references.
+ // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
+ // directly.
+ std::unique_ptr<std::vector<PyObject*>> flattened_attrs = nullptr;
+ std::unique_ptr<std::vector<PyObject*>> flattened_inputs = nullptr;
+
+ if (run_callbacks) {
+ flattened_attrs.reset(new std::vector<PyObject*>);
+ flattened_inputs.reset(new std::vector<PyObject*>);
+ }
+
+ // Add inferred attrs and inputs.
+ // The following code might set duplicate type attrs. This will result in
+ // the CacheKey for the generated AttrBuilder possibly differing from
+ // those where the type attrs are correctly set. Inconsistent CacheKeys
+ // for ops means that there might be unnecessarily duplicated kernels.
+ // TODO(nareshmodi): Fix this.
for (int i = 0; i < op_def->input_arg_size(); i++) {
const auto& input_arg = op_def->input_arg(i);
PyObject* input =
PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
- TFE_TensorHandle* input_handle = EagerTensor_Handle(input);
-
- // The following code might set duplicate type attrs. This will result in
- // the CacheKey for the generated AttrBuilder possibly differing from those
- // where the type attrs are correctly set. Inconsistent CacheKeys for ops
- // means that there might be unnecessarily duplicated kernels.
- // TODO(nareshmodi): Fix this.
- if (!input_arg.type_attr().empty()) {
- TFE_OpSetAttrType(op, input_arg.type_attr().data(),
- TFE_TensorHandleDataType(input_handle));
+ if (!input_arg.number_attr().empty()) {
+ // The item is a homogeneous list.
+ if (!RaiseIfNotPyList(input, input_arg.number_attr())) return nullptr;
+ Py_ssize_t len = PyList_Size(input);
+
+ TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
+ if (run_callbacks) {
+ flattened_attrs->push_back(
+ GetPythonObjectFromString(input_arg.number_attr().data()));
+ flattened_attrs->push_back(PyLong_FromLong(len));
+ }
+ attr_list_sizes[input_arg.number_attr()] = len;
+
+ if (len > 0) {
+ // First item adds the type attr.
+ if (!AddInputToOp(PyList_GET_ITEM(input, 0), &input_arg,
+ flattened_attrs.get(), flattened_inputs.get(), op,
+ status)) {
+ return nullptr;
+ }
+
+ for (Py_ssize_t j = 1; j < len; j++) {
+ // Since the list is homogeneous, we don't need to re-add the attr.
+ if (!AddInputToOp(PyList_GET_ITEM(input, j), nullptr /* input_arg */,
+ nullptr /* flattened_attrs */,
+ flattened_inputs.get(), op, status)) {
+ return nullptr;
+ }
+ }
+ }
+ } else if (!input_arg.type_list_attr().empty()) {
+ // The item is a heterogeneous list.
+ if (!RaiseIfNotPyList(input, input_arg.type_list_attr())) return nullptr;
+ const string& attr_name = input_arg.type_list_attr();
+ Py_ssize_t len = PyList_Size(input);
+ tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
+ PyObject* py_attr_value = nullptr;
+ if (run_callbacks) {
+ py_attr_value = PyTuple_New(len);
+ }
+ for (Py_ssize_t j = 0; j < len; j++) {
+ PyObject* py_input = PyList_GET_ITEM(input, j);
+ TFE_TensorHandle* input_handle = EagerTensor_Handle(py_input);
+ attr_value[j] = TFE_TensorHandleDataType(input_handle);
+
+ TFE_OpAddInput(op, input_handle, status);
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ if (run_callbacks) {
+ flattened_inputs->push_back(py_input);
+
+ PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
+ }
+ }
+ if (run_callbacks) {
+ flattened_attrs->push_back(GetPythonObjectFromString(attr_name.data()));
+ flattened_attrs->push_back(py_attr_value);
+ }
+ TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
+ attr_value.size());
+ attr_list_sizes[attr_name] = len;
+ } else {
+ // The item is a single item.
+ if (!AddInputToOp(input, &input_arg, flattened_attrs.get(),
+ flattened_inputs.get(), op, status)) {
+ return nullptr;
+ }
}
+ }
- TFE_OpAddInput(op, input_handle, status);
- if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
- return nullptr;
+ int num_retvals = 0;
+ for (int i = 0; i < op_def->output_arg_size(); i++) {
+ const auto& output_arg = op_def->output_arg(i);
+ if (!output_arg.number_attr().empty()) {
+ num_retvals += attr_list_sizes[output_arg.number_attr()];
+ } else if (!output_arg.type_list_attr().empty()) {
+ num_retvals += attr_list_sizes[output_arg.type_list_attr()];
+ } else {
+ num_retvals++;
}
}
- int num_retvals = op_def->output_arg_size();
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
Py_BEGIN_ALLOW_THREADS;
TFE_Execute(op, retvals.data(), &num_retvals, status);
Py_END_ALLOW_THREADS;
- if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ if (TF_GetCode(status) != TF_OK) {
+ // Augment the status with the op_name for easier debugging similar to
+ // TFE_Py_Execute.
+ TF_SetStatus(status, TF_GetCode(status),
+ tensorflow::strings::StrCat(TF_Message(status), " [Op:",
+ TFE_GetPythonString(op_name), "]")
+ .c_str());
+
+ MaybeRaiseExceptionFromTFStatus(status, nullptr);
return nullptr;
}
- PyObject* result = PyTuple_New(num_retvals);
+ PyObject* flat_result = PyList_New(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
- PyTuple_SET_ITEM(result, i, EagerTensorFromHandle(retvals[i]));
+ PyList_SET_ITEM(flat_result, i, EagerTensorFromHandle(retvals[i]));
}
- if (!MaybeRunRecordGradientCallback(op_def, args, result,
- record_gradient_callback)) {
+ if (run_callbacks &&
+ !RunCallbacks(run_gradient_callback, run_post_exec_callbacks, op_def,
+ args, *flattened_inputs, *flattened_attrs, flat_result,
+ op_name, name, record_gradient_callback, callbacks)) {
return nullptr;
}
+ // Unflatten results.
+ if (op_def->output_arg_size() == 0) {
+ Py_RETURN_NONE;
+ }
+
+ if (op_def->output_arg_size() == 1) {
+ if (!op_def->output_arg(0).number_attr().empty() ||
+ !op_def->output_arg(0).type_list_attr().empty()) {
+ return flat_result;
+ } else {
+ auto* result = PyList_GET_ITEM(flat_result, 0);
+ Py_INCREF(result);
+ Py_DECREF(flat_result);
+ return result;
+ }
+ }
+
+ // Group the flat results by output arg (callers turn this into a namedtuple).
+ PyObject* result = PyList_New(op_def->output_arg_size());
+ int flat_result_index = 0;
+ for (int i = 0; i < op_def->output_arg_size(); i++) {
+ if (!op_def->output_arg(i).number_attr().empty()) {
+ int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()];
+ PyObject* inner_list = PyList_New(list_length);
+ for (int j = 0; j < list_length; j++) {
+ PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
+ Py_INCREF(obj);
+ PyList_SET_ITEM(inner_list, j, obj);
+ }
+ PyList_SET_ITEM(result, i, inner_list);
+ } else if (!op_def->output_arg(i).type_list_attr().empty()) {
+ int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()];
+ PyObject* inner_list = PyList_New(list_length);
+ for (int j = 0; j < list_length; j++) {
+ PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
+ Py_INCREF(obj);
+ PyList_SET_ITEM(inner_list, j, obj);
+ }
+ PyList_SET_ITEM(result, i, inner_list);
+ } else {
+ PyObject* obj = PyList_GET_ITEM(flat_result, flat_result_index++);
+ Py_INCREF(obj);
+ PyList_SET_ITEM(result, i, obj);
+ }
+ }
+ Py_DECREF(flat_result);
return result;
}
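
The tail of TFE_Py_FastPathExecute_C regroups the flat result list by output
arg before returning: no outputs yields None, a single list-typed output
yields the flat list itself, a single plain output yields the lone tensor, and
multiple outputs yield a list of per-arg entries (from which callers build a
namedtuple). A small Python model of that bookkeeping, with plain lists
standing in for the OpDef and the tensors (illustrative only):

    def unflatten_results(output_args, attr_list_sizes, flat_result):
      """output_args: list of (number_attr, type_list_attr), '' when unset."""
      if not output_args:
        return None
      if len(output_args) == 1:
        number_attr, type_list_attr = output_args[0]
        if number_attr or type_list_attr:
          return flat_result  # single list-typed output: keep the flat list
        return flat_result[0]  # single plain output: unwrap it
      result, idx = [], 0
      for number_attr, type_list_attr in output_args:
        size_attr = number_attr or type_list_attr
        if size_attr:
          # List-typed output: its length was recorded in attr_list_sizes.
          n = attr_list_sizes[size_attr]
          result.append(flat_result[idx:idx + n])
          idx += n
        else:
          result.append(flat_result[idx])
          idx += 1
      return result
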
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index d4f4ed592f..49323e6640 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,28 +21,15 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import execute
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-def record_gradient_callback(inputs, attrs, results):
- return backprop._record_gradient("MatMul", inputs, attrs, results, None)
-
-
-def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False):
- ctx = context.context()
- assert not ctx.in_graph_mode(
- ), "The prototype doesn't contain C code for graph construction"
- ctx_handle = ctx._handle # pylint: disable=protected-access
-
- return pywrap_tensorflow.TFE_Py_FastPathExecute(
- ctx_handle, ctx.device_name, "MatMul", record_gradient_callback, a, b,
- "transpose_a", transpose_a, "transpose_b", transpose_b)[0]
-
-
class Tests(test.TestCase):
@test_util.assert_no_new_tensors
@@ -54,31 +41,100 @@ class Tests(test.TestCase):
a_100_by_784 = random_ops.random_uniform((100, 784))
b_100_by_784 = random_ops.random_uniform((100, 784))
+ ctx = context.context()
+
self.assertAllClose(
math_ops.matmul(a_2_by_2, b_2_by_2),
- c_tfe_py_fastpath_execute(a_2_by_2, b_2_by_2))
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", execute.record_gradient,
+ None, None, a_2_by_2, b_2_by_2, "transpose_a", False, "transpose_b",
+ False))
self.assertAllClose(
math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
- c_tfe_py_fastpath_execute(a_100_by_784, b_100_by_784, transpose_b=True))
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", execute.record_gradient,
+ None, None, a_100_by_784, b_100_by_784, "transpose_a", False,
+ "transpose_b", True))
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
def testFastpathExecute_TapeWrite(self):
+ ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
tape.watch(a_2_by_2)
- z = c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2)
+ z = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, None,
+ None, a_2_by_2, a_2_by_2, "transpose_a", False, "transpose_b", False)
dz_dy = tape.gradient(z, [a_2_by_2])[0]
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
+ # Tests homogeneous list op
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
- def testFastpathExecute_MatMulSlowPath(self):
- a_2_by_2 = random_ops.random_uniform((2, 2)).cpu().numpy()
+ def testFastpathExecute_AddNCorrectResponse(self):
+ ctx = context.context()
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ b_2_by_2 = random_ops.random_uniform((2, 2))
+
+ self.assertAllClose(
+ math_ops.add_n([a_2_by_2, b_2_by_2]),
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None,
+ None, [a_2_by_2, b_2_by_2]))
+
+ # Tests homogeneous list op
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_AddNTapeWrite(self):
+ ctx = context.context()
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ b_2_by_2 = random_ops.random_uniform((2, 2))
- with self.assertRaises(NotImplementedError):
- c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(a_2_by_2)
+ tape.watch(b_2_by_2)
+ z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "AddN", execute.record_gradient, None,
+ None, [a_2_by_2, b_2_by_2])
+ z2 = math_ops.add_n([a_2_by_2, b_2_by_2])
+ dz1_dy = tape.gradient(z1, [a_2_by_2])[0]
+ dz2_dy = tape.gradient(z2, [a_2_by_2])[0]
+ self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy())
+
+ # Tests heterogeneous list op
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_IdentityNCorrectResponse(self):
+ ctx = context.context()
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ b_2_by_2 = random_ops.random_uniform((2, 2))
+
+ self.assertAllClose(
+ array_ops.identity_n([a_2_by_2, b_2_by_2]),
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient,
+ None, None, [a_2_by_2, b_2_by_2]))
+
+ # Tests heterogeneous list op
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_IdentityNTapeWrite(self):
+ ctx = context.context()
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ b_2_by_2 = random_ops.random_uniform((2, 2))
+
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(a_2_by_2)
+ tape.watch(b_2_by_2)
+ z1 = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "IdentityN", execute.record_gradient,
+ None, None, [a_2_by_2, b_2_by_2])
+ z2 = array_ops.identity_n([a_2_by_2, b_2_by_2])
+ dz1_dy = tape.gradient(z1[0], [a_2_by_2])[0]
+ dz2_dy = tape.gradient(z2[0], [a_2_by_2])[0]
+ self.assertAllEqual(dz1_dy.numpy(), dz2_dy.numpy())
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
@@ -89,20 +145,24 @@ class Tests(test.TestCase):
), "The prototype doesn't contain C code for graph construction"
ctx_handle = ctx._handle # pylint: disable=protected-access
+ # Not enough base params
with self.assertRaisesRegexp(ValueError,
- "at least 4 items in the input tuple"):
+ "at least 6 items in the input tuple"):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
"Identity")
+ # Not enough inputs
with self.assertRaisesRegexp(ValueError,
- "Expected to be at least 5, was 4"):
+ "Expected to be at least 7, was 6"):
pywrap_tensorflow.TFE_Py_FastPathExecute(
- ctx_handle, ctx_handle, "Identity", record_gradient_callback)
+ ctx_handle, ctx_handle, "Identity", backprop._record_gradient, None,
+ [])
+ # Bad type
with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
pywrap_tensorflow.TFE_Py_FastPathExecute(
- ctx_handle, ctx.device_name, ctx_handle, record_gradient_callback,
- a_2_by_2)
+ ctx_handle, ctx.device_name, ctx_handle, backprop._record_gradient,
+ None, [], a_2_by_2)
if __name__ == "__main__":
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 3f25311a83..50f481d29e 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,6 +29,7 @@ limitations under the License.
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
+%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_FastPathExecute;
%rename("%s") TFE_Py_UID;