diff options
author | Alexandre Passos <apassos@google.com> | 2017-11-10 15:29:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-10 17:02:59 -0800 |
commit | 32ed29515c23bb9ff74d4343693d581754192922 (patch) | |
tree | ca213f9e44e55f77a768e0bfb154d7fc9cd594c6 | |
parent | ce239783908dc2db0eed472344fe8d50ce9a6c9c (diff) |
Improvement to benchmark.
PiperOrigin-RevId: 175346269
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 66 | ||||
-rw-r--r-- | tensorflow/python/eager/tape.py | 5 |
2 files changed, 44 insertions, 27 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 77b49be8f8..372a6bb4b7 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -600,11 +600,33 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { return tensorflow::eager::TapeTensor{id, dtype, shape}; } +std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { + PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); + if (seq == nullptr) { + return {}; + } + int len = PySequence_Fast_GET_SIZE(seq); + std::vector<tensorflow::int64> list; + list.reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); + if (EagerTensor_CheckExact(tensor)) { + list.push_back(EagerTensor_id(tensor)); + } else { + PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); + list.push_back(MakeInt(id_field)); + Py_DECREF(id_field); + } + } + Py_DECREF(seq); + return list; +} + void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, PyObject* output_tensors, - PyObject* input_tensor_ids, + PyObject* input_tensors, PyObject* backward_function) { - std::vector<tensorflow::int64> input_ids = MakeIntList(input_tensor_ids); + std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); std::vector<tensorflow::eager::TapeTensor> output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -619,9 +641,26 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, } } Py_DECREF(seq); + char* op_type_str = nullptr; + if (PyBytes_Check(op_type)) { + op_type_str = PyBytes_AsString(op_type); + } else if (PyUnicode_Check(op_type)) { +#if PY_MAJOR_VERSION >= 3 + op_type_str = PyUnicode_AsUTF8(op_type); +#else + PyObject* py_str = PyUnicode_AsUTF8String(op_type); + if (py_str == nullptr) return; + op_type_str = PyBytes_AS_STRING(py_str); + Py_DECREF(py_str); +#endif + } else { + PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); + return; + } + Py_INCREF(backward_function); reinterpret_cast<TFE_Py_Tape*>(tape)->tape->RecordOperation( - PyBytes_AsString(op_type), output_info, input_ids, backward_function, + op_type_str, output_info, input_ids, backward_function, [backward_function]() { Py_DECREF(backward_function); }); } @@ -794,27 +833,6 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) { return list; } -std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { - PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); - if (seq == nullptr) { - return {}; - } - int len = PySequence_Fast_GET_SIZE(seq); - std::vector<tensorflow::int64> list; - list.reserve(len); - for (int i = 0; i < len; ++i) { - PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); - if (EagerTensor_CheckExact(tensor)) { - list.push_back(EagerTensor_id(tensor)); - } else { - PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); - list.push_back(MakeInt(id_field)); - Py_DECREF(id_field); - } - } - Py_DECREF(seq); - return list; -} PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, PyObject* target, PyObject* sources, diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index a06f5e1a67..afbad183b0 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -23,7 +23,6 @@ import contextlib import threading from tensorflow.python import pywrap_tensorflow -from tensorflow.python.util import compat def tid(tensor): @@ -87,9 +86,9 @@ class Tape(object): """Records an operation in the tape.""" pywrap_tensorflow.TFE_Py_TapeRecordOperation( self._tape, - compat.as_bytes(op_type), + op_type, output_tensors, - [x._id for x in input_tensors], # pylint: disable=protected-access + input_tensors, backward_function) def _delete_tensor_id(self, i): |