aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-09-19 14:54:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 14:57:27 -0700
commitc3014ec19e23e4aad7286b3fac6b25a5fb4a6326 (patch)
tree1ffab9c512fe71884a19851b63c1025d487dbe3e /tensorflow/python/eager
parent4e7d5f008be62bb7ca3e1646af8d4f22287d9e50 (diff)
Allow the tape tensor to have unknown shapes.
This is done by making the TapeTensor a template rather than a concrete struct. PiperOrigin-RevId: 213700425
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/backprop.py5
-rw-r--r--tensorflow/python/eager/backprop_test.py12
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc458
5 files changed, 286 insertions, 195 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d443..f80256fc2a 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -34,6 +34,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 50a6ce6324..d95e0fe721 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -608,8 +608,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a0f6be459e..196e20e4d7 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,7 +1403,24 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
@@ -1183,16 +1428,16 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,180 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- if (EagerTensor_CheckExact(tensor)) {
- return PyEagerTensor_NumElements(tensor);
- }
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- Py_DECREF(arglist);
- if (result == nullptr) {
- // The caller detects whether a python exception has been raised.
- return -1;
- }
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {