aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-11-10 15:29:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-10 17:02:59 -0800
commit32ed29515c23bb9ff74d4343693d581754192922 (patch)
treeca213f9e44e55f77a768e0bfb154d7fc9cd594c6
parentce239783908dc2db0eed472344fe8d50ce9a6c9c (diff)
Improvement to benchmark.
PiperOrigin-RevId: 175346269
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc66
-rw-r--r--tensorflow/python/eager/tape.py5
2 files changed, 44 insertions, 27 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 77b49be8f8..372a6bb4b7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -600,11 +600,33 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
return tensorflow::eager::TapeTensor{id, dtype, shape};
}
+std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
+ PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+ if (seq == nullptr) {
+ return {};
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ std::vector<tensorflow::int64> list;
+ list.reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+ if (EagerTensor_CheckExact(tensor)) {
+ list.push_back(EagerTensor_id(tensor));
+ } else {
+ PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
+ list.push_back(MakeInt(id_field));
+ Py_DECREF(id_field);
+ }
+ }
+ Py_DECREF(seq);
+ return list;
+}
+
void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
PyObject* output_tensors,
- PyObject* input_tensor_ids,
+ PyObject* input_tensors,
PyObject* backward_function) {
- std::vector<tensorflow::int64> input_ids = MakeIntList(input_tensor_ids);
+ std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
@@ -619,9 +641,26 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
}
}
Py_DECREF(seq);
+ char* op_type_str = nullptr;
+ if (PyBytes_Check(op_type)) {
+ op_type_str = PyBytes_AsString(op_type);
+ } else if (PyUnicode_Check(op_type)) {
+#if PY_MAJOR_VERSION >= 3
+ op_type_str = PyUnicode_AsUTF8(op_type);
+#else
+ PyObject* py_str = PyUnicode_AsUTF8String(op_type);
+ if (py_str == nullptr) return;
+ op_type_str = PyBytes_AS_STRING(py_str);
+ Py_DECREF(py_str);
+#endif
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
+ return;
+ }
+
Py_INCREF(backward_function);
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->RecordOperation(
- PyBytes_AsString(op_type), output_info, input_ids, backward_function,
+ op_type_str, output_info, input_ids, backward_function,
[backward_function]() { Py_DECREF(backward_function); });
}
@@ -794,27 +833,6 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
return list;
}
-std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
- PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
- if (seq == nullptr) {
- return {};
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- std::vector<tensorflow::int64> list;
- list.reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
- if (EagerTensor_CheckExact(tensor)) {
- list.push_back(EagerTensor_id(tensor));
- } else {
- PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
- list.push_back(MakeInt(id_field));
- Py_DECREF(id_field);
- }
- }
- Py_DECREF(seq);
- return list;
-}
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
PyObject* target, PyObject* sources,
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index a06f5e1a67..afbad183b0 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -23,7 +23,6 @@ import contextlib
import threading
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.util import compat
def tid(tensor):
@@ -87,9 +86,9 @@ class Tape(object):
"""Records an operation in the tape."""
pywrap_tensorflow.TFE_Py_TapeRecordOperation(
self._tape,
- compat.as_bytes(op_type),
+ op_type,
output_tensors,
- [x._id for x in input_tensors], # pylint: disable=protected-access
+ input_tensors,
backward_function)
def _delete_tensor_id(self, i):