author | Akshay Modi <nareshmodi@google.com> | 2018-09-19 14:54:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 14:57:27 -0700 |
commit | c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 (patch) | |
tree | 1ffab9c512fe71884a19851b63c1025d487dbe3e /tensorflow/python/eager | |
parent | 4e7d5f008be62bb7ca3e1646af8d4f22287d9e50 (diff) |
Allow the tape tensor to have unknown shapes.
This is done by making the TapeTensor a template rather than a concrete struct.
PiperOrigin-RevId: 213700425
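The user-visible effect mirrors the testUnknownShapes test added below: a gradient can now be recorded through a graph-mode tensor whose static shape is entirely unknown. A minimal standalone sketch (an assumption-laden illustration, not code from this commit: it uses public TF 1.x-era APIs and assumes a build that includes this change, while the test itself uses the internal backprop/array_ops modules):

```python
import tensorflow as tf  # assumes a TF 1.x-era build that includes this change

with tf.Graph().as_default():
  # A placeholder with a fully unknown static shape. Before this change the
  # tape needed a concrete TensorShape to build the initial ones/zeros grads.
  a = tf.placeholder(dtype=tf.float32, shape=None)
  with tf.GradientTape() as tape:
    tape.watch(a)
    b = a ** 3
  db_da = tape.gradient(b, a)

  with tf.Session() as sess:
    # d(a**3)/da = 3 * a**2, so feeding a = 2.0 yields (8.0, 12.0).
    print(sess.run((b, db_da), feed_dict={a: 2.0}))
```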
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/BUILD | 1
-rw-r--r-- | tensorflow/python/eager/backprop.py | 5
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 12
-rw-r--r-- | tensorflow/python/eager/imperative_grad.py | 5
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 458 |
5 files changed, 286 insertions, 195 deletions
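At a glance, the Python-side surface of the change, assembled here for readability from the imperative_grad.py and backprop.py hunks below (the C++ hunks supply the PyTapeTensor plumbing that consumes graph_shape_fn):

```python
import collections

# imperative_grad.py: the VSpace fields pick up explicit *_fn names and gain
# graph_shape_fn, which the C++ tape calls when a watched tensor's static
# shape is unknown.
VSpace = collections.namedtuple("VSpace", [
    "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
])

# backprop.py: the default vspace wires graph_shape_fn to the raw shape op,
# so ones/zeros gradients can be built from a shape computed at runtime:
#
#   _default_vspace = imperative_grad.VSpace(
#       num_elements_fn=_num_elements,
#       aggregate_fn=_aggregate_grads,
#       zeros_fn=_zeros,
#       ones_fn=_ones,
#       graph_shape_fn=gen_array_ops.shape)
```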
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d443..f80256fc2a 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -34,6 +34,7 @@ cc_library(
         "//tensorflow/python:safe_ptr",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/types:variant",
     ],
 )
 
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 50a6ce6324..d95e0fe721 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -608,8 +608,9 @@ def _ones(shape, dtype):
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
-    zeros=_zeros,
-    ones=_ones)
+    zeros_fn=_zeros,
+    ones_fn=_ones,
+    graph_shape_fn=gen_array_ops.shape)
 pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
 
 
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
         resource_variable_ops.ResourceVariable(2.0))
     self.assertAllEqual(gradients_constants, gradients_variables)
 
+  def testUnknownShapes(self):
+    with context.graph_mode():
+      with backprop.GradientTape() as tape:
+        a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+        tape.watch(a)
+        b = a**3
+
+      db_da = tape.gradient(b, a)
+
+      with self.cached_session() as sess:
+        self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
 
 from tensorflow.python import pywrap_tensorflow
 
-VSpace = collections.namedtuple(
-    "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+    "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
 
 
 def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a0f6be459e..196e20e4d7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h" +#include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { return static_cast<tensorflow::DataType>(id); } +class PyTapeTensor { + public: + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + const tensorflow::TensorShape& shape) + : id_(id), dtype_(dtype), shape_(shape) {} + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + PyObject* shape) + : id_(id), dtype_(dtype), shape_(shape) { + Py_INCREF(absl::get<1>(shape_)); + } + PyTapeTensor(const PyTapeTensor& other) { + id_ = other.id_; + dtype_ = other.dtype_; + shape_ = other.shape_; + if (shape_.index() == 1) { + Py_INCREF(absl::get<1>(shape_)); + } + } + + ~PyTapeTensor() { + if (shape_.index() == 1) { + Py_DECREF(absl::get<1>(shape_)); + } + } + PyObject* GetShape() const; + PyObject* GetDType() const { return PyLong_FromLong(dtype_); } + tensorflow::int64 GetID() const { return id_; } + + private: + tensorflow::int64 id_; + tensorflow::DataType dtype_; + absl::variant<tensorflow::TensorShape, PyObject*> shape_; +}; + +class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, + PyTapeTensor> { + public: + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { + Py_INCREF(py_vspace_); + } + + tensorflow::Status Initialize() { + num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); + if (num_elements_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); + if (aggregate_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); + if (zeros_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); + if (ones_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); + if (graph_shape_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + return tensorflow::Status::OK(); + } + + ~PyVSpace() override { + Py_XDECREF(num_elements_); + Py_XDECREF(aggregate_fn_); + Py_XDECREF(zeros_fn_); + Py_XDECREF(ones_fn_); + Py_XDECREF(graph_shape_fn_); + + Py_DECREF(py_vspace_); + } + + tensorflow::int64 NumElements(PyObject* tensor) const final { + if (EagerTensor_CheckExact(tensor)) { + return PyEagerTensor_NumElements(tensor); + } + PyObject* arglist = + Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); + PyObject* result = PyEval_CallObject(num_elements_, arglist); + Py_DECREF(arglist); + if (result == nullptr) { + // The caller detects whether a python exception has been raised. + return -1; + } + tensorflow::int64 r = MakeInt(result); + Py_DECREF(result); + return r; + } + + PyObject* AggregateGradients( + tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { + PyObject* list = PyList_New(gradient_tensors.size()); + for (int i = 0; i < gradient_tensors.size(); ++i) { + // Note: stealing a reference to the gradient tensors. 
+      CHECK(gradient_tensors[i] != nullptr);
+      CHECK(gradient_tensors[i] != Py_None);
+      PyList_SET_ITEM(list, i,
+                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
+    }
+    PyObject* arglist = Py_BuildValue("(O)", list);
+    CHECK(arglist != nullptr);
+    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+    Py_DECREF(arglist);
+    Py_DECREF(list);
+    return result;
+  }
+
+  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+  PyObject* Zeros(const PyTapeTensor& tensor) const final {
+    PyObject* py_shape = tensor.GetShape();
+    PyObject* py_dtype = tensor.GetDType();
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+    Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
+    return reinterpret_cast<PyObject*>(result);
+  }
+
+  PyObject* Ones(const PyTapeTensor& tensor) const final {
+    PyObject* py_shape = tensor.GetShape();
+    PyObject* py_dtype = tensor.GetDType();
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+    Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
+    return result;
+  }
+
+  PyObject* GraphShape(PyObject* tensor) const {
+    PyObject* arg_list = Py_BuildValue("(O)", tensor);
+    PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+    Py_DECREF(arg_list);
+    return result;
+  }
+
+  tensorflow::Status CallBackwardFunction(
+      PyBackwardFunction* backward_function,
+      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+      std::vector<PyObject*>* result) const final {
+    PyObject* grads = PyTuple_New(output_gradients.size());
+    for (int i = 0; i < output_gradients.size(); ++i) {
+      if (output_gradients[i] == nullptr) {
+        Py_INCREF(Py_None);
+        PyTuple_SET_ITEM(grads, i, Py_None);
+      } else {
+        PyTuple_SET_ITEM(grads, i,
+                         reinterpret_cast<PyObject*>(output_gradients[i]));
+      }
+    }
+    PyObject* py_result = (*backward_function)(grads);
+    Py_DECREF(grads);
+    if (py_result == nullptr) {
+      return tensorflow::errors::Internal("gradient function threw exceptions");
+    }
+    result->clear();
+    PyObject* seq =
+        PySequence_Fast(py_result, "expected a sequence of gradients");
+    if (seq == nullptr) {
+      return tensorflow::errors::InvalidArgument(
+          "gradient function did not return a list");
+    }
+    int len = PySequence_Fast_GET_SIZE(seq);
+    VLOG(1) << "Gradient length is " << len;
+    result->reserve(len);
+    for (int i = 0; i < len; ++i) {
+      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+      if (item == Py_None) {
+        result->push_back(nullptr);
+      } else {
+        Py_INCREF(item);
+        result->push_back(item);
+      }
+    }
+    Py_DECREF(seq);
+    Py_DECREF(py_result);
+    return tensorflow::Status::OK();
+  }
+
+  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+  PyObject* py_vspace_;
+
+  PyObject* num_elements_;
+  PyObject* aggregate_fn_;
+  PyObject* zeros_fn_;
+  PyObject* ones_fn_;
+  PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+  if (py_vspace != nullptr) {
+    delete py_vspace;
+  }
+
+  py_vspace = new PyVSpace(e);
+  auto status = py_vspace->Initialize();
+  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+    delete py_vspace;
+    return nullptr;
+  }
+
+  Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+  if (shape_.index() == 0) {
+    auto& shape = absl::get<0>(shape_);
+    PyObject* py_shape = PyTuple_New(shape.dims());
+    for (int i = 0; i < shape.dims(); ++i) {
+      PyTuple_SET_ITEM(py_shape, i,
+                       PyLong_FromLong(shape.dim_size(i)));
+    }
+
+    return py_shape;
+  }
+
+  return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
 class GradientTape
-    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                             PyTapeTensor> {
  public:
   explicit GradientTape(bool persistent, bool watch_accessed_variables)
-      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
-            persistent),
+      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                        PyTapeTensor>(persistent),
         watch_accessed_variables_(watch_accessed_variables) {}
 
   virtual ~GradientTape() {
@@ -1175,7 +1403,24 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
 }
 
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+  if (list == Py_None) return true;
+  tensorflow::Safe_PyObjectPtr seq(
+      PySequence_Fast(list, "expected a sequence"));
+  if (seq == nullptr) {
+    return false;
+  }
+
+  int len = PySequence_Size(list);
+  for (int i = 0; i < len; ++i) {
+    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+    if (item == Py_None) return true;
+  }
+
+  return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
     tensorflow::int64 id = PyEagerTensor_ID(tensor);
@@ -1183,16 +1428,16 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
     const tensorflow::Status status = t->handle->Shape(&tensor_shape);
 
     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
-      return tensorflow::eager::TapeTensor{id, t->handle->dtype,
-                                           tensorflow::TensorShape({})};
+      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                          tensorflow::TensorShape({}));
     } else {
-      return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+      return PyTapeTensor(id, t->handle->dtype, tensor_shape);
     }
   }
   tensorflow::int64 id = FastTensorId(tensor);
   if (PyErr_Occurred()) {
-    return tensorflow::eager::TapeTensor{
-        id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
   tensorflow::DataType dtype =
       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
   Py_DECREF(dtype_enum);
-  if (PyErr_Occurred() != nullptr) {
-    return tensorflow::eager::TapeTensor{id, dtype,
-                                         tensorflow::TensorShape({})};
+  if (PyErr_Occurred()) {
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
   static char _shape_tuple[] = "_shape_tuple";
   PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
-  if (PyErr_Occurred() != nullptr) {
-    return tensorflow::eager::TapeTensor{id, dtype,
-                                         tensorflow::TensorShape({})};
+  if (PyErr_Occurred()) {
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
+
+  if (ListContainsNone(shape_tuple)) {
+    return PyTapeTensor(id, dtype, tensor);
+  }
+
   auto l = MakeIntList(shape_tuple);
   Py_DECREF(shape_tuple);
   // Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
     }
   }
   tensorflow::TensorShape shape(l);
-  return tensorflow::eager::TapeTensor{id, dtype, shape};
+  return PyTapeTensor(id, dtype, shape);
 }
 
 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
     const std::vector<tensorflow::DataType>& input_dtypes,
     const std::function<PyBackwardFunction*()>& backward_function_getter,
     const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
-  std::vector<tensorflow::eager::TapeTensor> output_info;
+  std::vector<PyTapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
                                   "expected a sequence of integer tensor ids");
   int len = PySequence_Size(output_tensors);
@@ -1362,180 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
   }
 }
 
-class PyVSpace
-    : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
-  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
-    Py_INCREF(py_vspace_);
-  }
-
-  tensorflow::Status Initialize() {
-    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
-    if (num_elements_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
-    if (aggregate_fn_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
-    if (zeros_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    ones_ =
-        PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
-    if (ones_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    return tensorflow::Status::OK();
-  }
-
-  ~PyVSpace() override {
-    Py_XDECREF(num_elements_);
-    Py_XDECREF(aggregate_fn_);
-    Py_XDECREF(zeros_);
-    Py_XDECREF(ones_);
-
-    Py_DECREF(py_vspace_);
-  }
-
-  tensorflow::int64 NumElements(PyObject* tensor) const final {
-    if (EagerTensor_CheckExact(tensor)) {
-      return PyEagerTensor_NumElements(tensor);
-    }
-    PyObject* arglist =
-        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
-    PyObject* result = PyEval_CallObject(num_elements_, arglist);
-    Py_DECREF(arglist);
-    if (result == nullptr) {
-      // The caller detects whether a python exception has been raised.
-      return -1;
-    }
-    tensorflow::int64 r = MakeInt(result);
-    Py_DECREF(result);
-    return r;
-  }
-
-  PyObject* AggregateGradients(
-      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
-    PyObject* list = PyList_New(gradient_tensors.size());
-    for (int i = 0; i < gradient_tensors.size(); ++i) {
-      // Note: stealing a reference to the gradient tensors.
-      CHECK(gradient_tensors[i] != nullptr);
-      CHECK(gradient_tensors[i] != Py_None);
-      PyList_SET_ITEM(list, i,
-                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
-    }
-    PyObject* arglist = Py_BuildValue("(O)", list);
-    CHECK(arglist != nullptr);
-    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
-    Py_DECREF(arglist);
-    Py_DECREF(list);
-    return result;
-  }
-
-  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
-  PyObject* Zeros(tensorflow::TensorShape shape,
-                  tensorflow::DataType dtype) const final {
-    PyObject* py_shape = PyTuple_New(shape.dims());
-    for (int i = 0; i < shape.dims(); ++i) {
-      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
-    }
-    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
-    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
-    PyObject* result = PyEval_CallObject(zeros_, arg_list);
-    Py_DECREF(arg_list);
-    Py_DECREF(py_dtype);
-    Py_DECREF(py_shape);
-    return reinterpret_cast<PyObject*>(result);
-  }
-
-  PyObject* Ones(tensorflow::TensorShape shape,
-                 tensorflow::DataType dtype) const final {
-    PyObject* py_shape = PyTuple_New(shape.dims());
-    for (int i = 0; i < shape.dims(); ++i) {
-      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
-    }
-    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
-    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
-    PyObject* result = PyEval_CallObject(ones_, arg_list);
-    Py_DECREF(arg_list);
-    Py_DECREF(py_dtype);
-    Py_DECREF(py_shape);
-    return result;
-  }
-
-  tensorflow::Status CallBackwardFunction(
-      PyBackwardFunction* backward_function,
-      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
-      std::vector<PyObject*>* result) const final {
-    PyObject* grads = PyTuple_New(output_gradients.size());
-    for (int i = 0; i < output_gradients.size(); ++i) {
-      if (output_gradients[i] == nullptr) {
-        Py_INCREF(Py_None);
-        PyTuple_SET_ITEM(grads, i, Py_None);
-      } else {
-        PyTuple_SET_ITEM(grads, i,
-                         reinterpret_cast<PyObject*>(output_gradients[i]));
-      }
-    }
-    PyObject* py_result = (*backward_function)(grads);
-    Py_DECREF(grads);
-    if (py_result == nullptr) {
-      return tensorflow::errors::Internal("gradient function threw exceptions");
-    }
-    result->clear();
-    PyObject* seq =
-        PySequence_Fast(py_result, "expected a sequence of gradients");
-    if (seq == nullptr) {
-      return tensorflow::errors::InvalidArgument(
-          "gradient function did not return a list");
-    }
-    int len = PySequence_Fast_GET_SIZE(seq);
-    VLOG(1) << "Gradient length is " << len;
-    result->reserve(len);
-    for (int i = 0; i < len; ++i) {
-      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
-      if (item == Py_None) {
-        result->push_back(nullptr);
-      } else {
-        Py_INCREF(item);
-        result->push_back(item);
-      }
-    }
-    Py_DECREF(seq);
-    Py_DECREF(py_result);
-    return tensorflow::Status::OK();
-  }
-
-  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
-  PyObject* py_vspace_;
-
-  PyObject* num_elements_;
-  PyObject* aggregate_fn_;
-  PyObject* zeros_;
-  PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
-  if (py_vspace != nullptr) {
-    delete py_vspace;
-  }
-
-  py_vspace = new PyVSpace(e);
-  auto status = py_vspace->Initialize();
-  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
-    delete py_vspace;
-    return nullptr;
-  }
-
-  Py_RETURN_NONE;
-}
-
 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   if (seq == nullptr) {
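(The diff is truncated here, partway through MakeTensorList.) The mechanism the C++ side introduces is nonetheless fully visible above: PyTapeTensor stores either a concrete TensorShape or the tensor object itself, and GetShape() falls back to py_vspace->GraphShape (the Python-side graph_shape_fn) only when the static shape was unknown. A rough Python analogue of that dispatch, using illustrative names that are not part of any TensorFlow API:

```python
import tensorflow as tf


class _TapeTensorSketch(object):
  """Illustrative analogue of the C++ PyTapeTensor (not a TensorFlow class).

  Mirrors the absl::variant<TensorShape, PyObject*> member: it holds either
  a concrete shape tuple or the graph tensor itself.
  """

  def __init__(self, tensor_id, dtype, static_shape, tensor=None):
    self._id = tensor_id
    self._dtype = dtype
    if static_shape is None or any(d is None for d in static_shape):
      # Unknown static shape: keep the tensor itself (variant index 1).
      self._shape_or_tensor = tensor
    else:
      # Fully known shape: keep a concrete tuple (variant index 0).
      self._shape_or_tensor = tuple(static_shape)

  def get_shape(self):
    if isinstance(self._shape_or_tensor, tuple):
      return self._shape_or_tensor  # known at trace time
    # Deferred case: emit a shape op (what graph_shape_fn, i.e.
    # gen_array_ops.shape, does), so the ones/zeros gradients get their
    # shape at session.run time instead of at trace time.
    return tf.shape(self._shape_or_tensor)
```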