From c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 19 Sep 2018 14:54:07 -0700 Subject: Allow the tape tensor to have unknown shapes. This is done by making the TapeTensor a template rather than a concrete struct. PiperOrigin-RevId: 213700425 --- tensorflow/c/eager/tape.h | 118 ++++---- tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/backprop.py | 5 +- tensorflow/python/eager/backprop_test.py | 12 + tensorflow/python/eager/imperative_grad.py | 5 +- tensorflow/python/eager/pywrap_tfe_src.cc | 458 +++++++++++++++++------------ 6 files changed, 342 insertions(+), 257 deletions(-) (limited to 'tensorflow') diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 49990b6249..41b5b8ff36 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -29,15 +29,8 @@ limitations under the License. namespace tensorflow { namespace eager { -// Information about a tensor. -struct TapeTensor { - int64 id; // Expected to be unique in the lifetime of this process. - DataType dtype; - TensorShape shape; -}; - // Represents an entry in the tape. -template +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; @@ -57,8 +50,8 @@ struct OpTapeEntry { using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. -template -using OpTape = gtl::FlatMap>; +template +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap>; // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle // specialization, which is blocked by quite a few things needing to loop back // into python now. -template +template class VSpace { public: virtual ~VSpace() {} @@ -93,10 +86,10 @@ class VSpace { gtl::ArraySlice gradient_tensors) const = 0; // Returns a tensor of the right shape and dtype filled with zeros. - virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Zeros(const TapeTensor& tensor) const = 0; // Returns a Tensor which is filled with ones and like the input. - virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + virtual Gradient* Ones(const TapeTensor& tensor) const = 0; // Calls the passed-in backward function. virtual Status CallBackwardFunction( @@ -114,7 +107,7 @@ class VSpace { // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. -template +template class GradientTape { public: // If `persistent` is true, GradientTape will not eagerly delete backward @@ -134,7 +127,7 @@ class GradientTape { void Watch(int64 tensor_id); void RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, @@ -146,17 +139,18 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. 
- Status ComputeGradient(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, - gtl::ArraySlice output_gradients, - std::vector* result); + Status ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); bool IsPersistent() const { return persistent_; } private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) { } } -template -bool GradientTape::ShouldRecord( +template +bool GradientTape::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { CHECK_EQ(tensor_ids.size(), dtypes.size()); @@ -201,14 +195,15 @@ bool GradientTape::ShouldRecord( return false; } -template -void GradientTape::Watch(int64 tensor_id) { +template +void GradientTape::Watch( + int64 tensor_id) { tensor_tape_.emplace(tensor_id, -1); } -template -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, +template +void GradientTape::RecordOperation( + const string& op_type, std::vector& output_tensors, gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, @@ -229,16 +224,18 @@ void GradientTape::RecordOperation( for (const TapeTensor& o : output_tensors) { // Note: the tensor can have already been watched and hence be in the tape, // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; tensors.push_back(o); } - op_tape_[op_id] = OpTapeEntry{ - op_type, tensors, ids, backward_function, backward_function_deleter}; + op_tape_[op_id] = OpTapeEntry{ + op_type, std::move(tensors), ids, backward_function, + backward_function_deleter}; } -template -void GradientTape::DeleteTrace(int64 tensor_id) { +template +void GradientTape::DeleteTrace( + int64 tensor_id) { auto it = tensor_usage_.find(tensor_id); if (it == tensor_usage_.end()) { return; @@ -261,7 +258,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { auto op_it = op_tape_.find(op_id); CHECK(op_it != op_tape_.end()); for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { // Found a usage for an output, so cannot delete the op. return; } @@ -304,9 +301,9 @@ void GradientTape::DeleteTrace(int64 tensor_id) { namespace { -template +template struct BackpropInitialState { - OpTape op_tape; + OpTape op_tape; // Map from tensor ID to how many references still exist for this tensor in // the tape. @@ -322,17 +319,17 @@ struct BackpropInitialState { // If `persistent_tape` is false, op_tape is cleared and backwards functions // not needed for gradient computation are deleted. Backwards functions that // are needed, are copied and returned in BackpropInitialState. 
-template -BackpropInitialState PrepareBackprop( +template +BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, const gtl::FlatSet& sources_set, - bool persistent_tape) { + OpTape* op_tape, + const gtl::FlatSet& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { tensor_stack.push_back(t); } - BackpropInitialState result; + BackpropInitialState result; while (!tensor_stack.empty()) { int64 tensor_id = tensor_stack.back(); tensor_stack.pop_back(); @@ -383,9 +380,9 @@ BackpropInitialState PrepareBackprop( return result; } -template +template std::vector InitialStack( - const OpTape& op_tape, + const OpTape& op_tape, const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { @@ -396,13 +393,13 @@ std::vector InitialStack( return result; } -template -Status InitialGradients(const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, - const TensorTape& tensor_tape, - const OpTape& op_tape, - gtl::FlatMap>* result) { +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { @@ -416,11 +413,10 @@ Status InitialGradients(const VSpace& vspace, } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { + if (op_it->second.output_tensor_info[j].GetID() == id) { found = true; (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); + vspace.Ones(op_it->second.output_tensor_info[j])); break; } } @@ -469,16 +465,16 @@ gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; -template -Status GradientTape::ComputeGradient( - const VSpace& vspace, +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, gtl::ArraySlice target_tensor_ids, gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); - BackpropInitialState state = PrepareBackprop( + BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); @@ -522,7 +518,7 @@ Status GradientTape::ComputeGradient( out_gradients.reserve(trace.output_tensor_info.size()); bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { - const int64 id = trace.output_tensor_info[i].id; + const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = @@ -531,9 +527,7 @@ Status GradientTape::ComputeGradient( func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { - out_gradients.push_back( - vspace.Zeros(trace.output_tensor_info[i].shape, - trace.output_tensor_info[i].dtype)); + out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i])); } } else { any_gradient_nonzero = true; diff --git 
a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c1bc27d443..f80256fc2a 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/python:safe_ptr", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", + "@com_google_absl//absl/types:variant", ], ) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 50a6ce6324..d95e0fe721 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -608,8 +608,9 @@ def _ones(shape, dtype): _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, - zeros=_zeros, - ones=_ones) + zeros_fn=_zeros, + ones_fn=_ones, + graph_shape_fn=gen_array_ops.shape) pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index f938ed5df8..32731747b7 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase): resource_variable_ops.ResourceVariable(2.0)) self.assertAllEqual(gradients_constants, gradients_variables) + def testUnknownShapes(self): + with context.graph_mode(): + with backprop.GradientTape() as tape: + a = array_ops.placeholder(dtype=dtypes.float32, shape=None) + tape.watch(a) + b = a**3 + + db_da = tape.gradient(b, a) + + with self.cached_session() as sess: + self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0})) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 5f027d107c..5f5af4ab6c 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -23,8 +23,9 @@ import collections from tensorflow.python import pywrap_tensorflow -VSpace = collections.namedtuple( - "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"]) +VSpace = collections.namedtuple("VSpace", [ + "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn" +]) def imperative_grad( diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index a0f6be459e..196e20e4d7 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/python/eager/pywrap_tfe.h" +#include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { return static_cast(id); } +class PyTapeTensor { + public: + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + const tensorflow::TensorShape& shape) + : id_(id), dtype_(dtype), shape_(shape) {} + PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, + PyObject* shape) + : id_(id), dtype_(dtype), shape_(shape) { + Py_INCREF(absl::get<1>(shape_)); + } + PyTapeTensor(const PyTapeTensor& other) { + id_ = other.id_; + dtype_ = other.dtype_; + shape_ = other.shape_; + if (shape_.index() == 1) { + Py_INCREF(absl::get<1>(shape_)); + } + } + + ~PyTapeTensor() { + if (shape_.index() == 1) { + Py_DECREF(absl::get<1>(shape_)); + } + } + PyObject* GetShape() const; + PyObject* GetDType() const { return PyLong_FromLong(dtype_); } + tensorflow::int64 GetID() const { return id_; } + + private: + tensorflow::int64 id_; + tensorflow::DataType dtype_; + absl::variant shape_; +}; + +class PyVSpace : public tensorflow::eager::VSpace { + public: + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { + Py_INCREF(py_vspace_); + } + + tensorflow::Status Initialize() { + num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); + if (num_elements_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); + if (aggregate_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); + if (zeros_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); + if (ones_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); + if (graph_shape_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + return tensorflow::Status::OK(); + } + + ~PyVSpace() override { + Py_XDECREF(num_elements_); + Py_XDECREF(aggregate_fn_); + Py_XDECREF(zeros_fn_); + Py_XDECREF(ones_fn_); + Py_XDECREF(graph_shape_fn_); + + Py_DECREF(py_vspace_); + } + + tensorflow::int64 NumElements(PyObject* tensor) const final { + if (EagerTensor_CheckExact(tensor)) { + return PyEagerTensor_NumElements(tensor); + } + PyObject* arglist = + Py_BuildValue("(O)", reinterpret_cast(tensor)); + PyObject* result = PyEval_CallObject(num_elements_, arglist); + Py_DECREF(arglist); + if (result == nullptr) { + // The caller detects whether a python exception has been raised. + return -1; + } + tensorflow::int64 r = MakeInt(result); + Py_DECREF(result); + return r; + } + + PyObject* AggregateGradients( + tensorflow::gtl::ArraySlice gradient_tensors) const final { + PyObject* list = PyList_New(gradient_tensors.size()); + for (int i = 0; i < gradient_tensors.size(); ++i) { + // Note: stealing a reference to the gradient tensors. 
+ CHECK(gradient_tensors[i] != nullptr); + CHECK(gradient_tensors[i] != Py_None); + PyList_SET_ITEM(list, i, + reinterpret_cast(gradient_tensors[i])); + } + PyObject* arglist = Py_BuildValue("(O)", list); + CHECK(arglist != nullptr); + PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); + Py_DECREF(arglist); + Py_DECREF(list); + return result; + } + + void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } + + PyObject* Zeros(const PyTapeTensor& tensor) const final { + PyObject* py_shape = tensor.GetShape(); + PyObject* py_dtype = tensor.GetDType(); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(zeros_fn_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return reinterpret_cast(result); + } + + PyObject* Ones(const PyTapeTensor& tensor) const final { + PyObject* py_shape = tensor.GetShape(); + PyObject* py_dtype = tensor.GetDType(); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(ones_fn_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return result; + } + + PyObject* GraphShape(PyObject* tensor) const { + PyObject* arg_list = Py_BuildValue("(O)", tensor); + PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list); + Py_DECREF(arg_list); + return result; + } + + tensorflow::Status CallBackwardFunction( + PyBackwardFunction* backward_function, + tensorflow::gtl::ArraySlice output_gradients, + std::vector* result) const final { + PyObject* grads = PyTuple_New(output_gradients.size()); + for (int i = 0; i < output_gradients.size(); ++i) { + if (output_gradients[i] == nullptr) { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(grads, i, Py_None); + } else { + PyTuple_SET_ITEM(grads, i, + reinterpret_cast(output_gradients[i])); + } + } + PyObject* py_result = (*backward_function)(grads); + Py_DECREF(grads); + if (py_result == nullptr) { + return tensorflow::errors::Internal("gradient function threw exceptions"); + } + result->clear(); + PyObject* seq = + PySequence_Fast(py_result, "expected a sequence of gradients"); + if (seq == nullptr) { + return tensorflow::errors::InvalidArgument( + "gradient function did not return a list"); + } + int len = PySequence_Fast_GET_SIZE(seq); + VLOG(1) << "Gradient length is " << len; + result->reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq, i); + if (item == Py_None) { + result->push_back(nullptr); + } else { + Py_INCREF(item); + result->push_back(item); + } + } + Py_DECREF(seq); + Py_DECREF(py_result); + return tensorflow::Status::OK(); + } + + void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } + + private: + PyObject* py_vspace_; + + PyObject* num_elements_; + PyObject* aggregate_fn_; + PyObject* zeros_fn_; + PyObject* ones_fn_; + PyObject* graph_shape_fn_; +}; +PyVSpace* py_vspace = nullptr; + +PyObject* TFE_Py_RegisterVSpace(PyObject* e) { + if (py_vspace != nullptr) { + delete py_vspace; + } + + py_vspace = new PyVSpace(e); + auto status = py_vspace->Initialize(); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + delete py_vspace; + return nullptr; + } + + Py_RETURN_NONE; +} + +PyObject* PyTapeTensor::GetShape() const { + if (shape_.index() == 0) { + auto& shape = absl::get<0>(shape_); + PyObject* py_shape = PyTuple_New(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); + } + + return py_shape; + } 
+ + return py_vspace->GraphShape(absl::get<1>(shape_)); +} + class GradientTape - : public tensorflow::eager::GradientTape { + : public tensorflow::eager::GradientTape { public: explicit GradientTape(bool persistent, bool watch_accessed_variables) - : tensorflow::eager::GradientTape( - persistent), + : tensorflow::eager::GradientTape(persistent), watch_accessed_variables_(watch_accessed_variables) {} virtual ~GradientTape() { @@ -1175,7 +1403,24 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { reinterpret_cast(tape)->tape->Watch(tensor_id); } -static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { +bool ListContainsNone(PyObject* list) { + if (list == Py_None) return true; + tensorflow::Safe_PyObjectPtr seq( + PySequence_Fast(list, "expected a sequence")); + if (seq == nullptr) { + return false; + } + + int len = PySequence_Size(list); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); + if (item == Py_None) return true; + } + + return false; +} + +static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = PyEagerTensor_ID(tensor); @@ -1183,16 +1428,16 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { const tensorflow::Status status = t->handle->Shape(&tensor_shape); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, - tensorflow::TensorShape({})}; + return PyTapeTensor(id, static_cast(0), + tensorflow::TensorShape({})); } else { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape}; + return PyTapeTensor(id, t->handle->dtype, tensor_shape); } } tensorflow::int64 id = FastTensorId(tensor); if (PyErr_Occurred()) { - return tensorflow::eager::TapeTensor{ - id, static_cast(0), tensorflow::TensorShape({})}; + return PyTapeTensor(id, static_cast(0), + tensorflow::TensorShape({})); } PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype"); PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum"); @@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::DataType dtype = static_cast(MakeInt(dtype_enum)); Py_DECREF(dtype_enum); - if (PyErr_Occurred() != nullptr) { - return tensorflow::eager::TapeTensor{id, dtype, - tensorflow::TensorShape({})}; + if (PyErr_Occurred()) { + return PyTapeTensor(id, static_cast(0), + tensorflow::TensorShape({})); } static char _shape_tuple[] = "_shape_tuple"; PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr); - if (PyErr_Occurred() != nullptr) { - return tensorflow::eager::TapeTensor{id, dtype, - tensorflow::TensorShape({})}; + if (PyErr_Occurred()) { + return PyTapeTensor(id, static_cast(0), + tensorflow::TensorShape({})); } + + if (ListContainsNone(shape_tuple)) { + return PyTapeTensor(id, dtype, tensor); + } + auto l = MakeIntList(shape_tuple); Py_DECREF(shape_tuple); // Replace -1, which represents accidental Nones which can occur in graph mode @@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { } } tensorflow::TensorShape shape(l); - return tensorflow::eager::TapeTensor{id, dtype, shape}; + return PyTapeTensor(id, dtype, shape); } std::vector MakeTensorIDList(PyObject* tensors) { @@ -1286,7 +1536,7 @@ void TapeSetRecordOperation( const std::vector& input_dtypes, const std::function& backward_function_getter, 
const std::function& backward_function_killer) { - std::vector output_info; + std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); int len = PySequence_Size(output_tensors); @@ -1362,180 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { } } -class PyVSpace - : public tensorflow::eager::VSpace { - public: - explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { - Py_INCREF(py_vspace_); - } - - tensorflow::Status Initialize() { - num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); - if (num_elements_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); - if (aggregate_fn_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - zeros_ = PyObject_GetAttrString(py_vspace_, "zeros"); - if (zeros_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - ones_ = - PyObject_GetAttrString(reinterpret_cast(py_vspace_), "ones"); - if (ones_ == nullptr) { - return tensorflow::errors::InvalidArgument("invalid vspace"); - } - return tensorflow::Status::OK(); - } - - ~PyVSpace() override { - Py_XDECREF(num_elements_); - Py_XDECREF(aggregate_fn_); - Py_XDECREF(zeros_); - Py_XDECREF(ones_); - - Py_DECREF(py_vspace_); - } - - tensorflow::int64 NumElements(PyObject* tensor) const final { - if (EagerTensor_CheckExact(tensor)) { - return PyEagerTensor_NumElements(tensor); - } - PyObject* arglist = - Py_BuildValue("(O)", reinterpret_cast(tensor)); - PyObject* result = PyEval_CallObject(num_elements_, arglist); - Py_DECREF(arglist); - if (result == nullptr) { - // The caller detects whether a python exception has been raised. - return -1; - } - tensorflow::int64 r = MakeInt(result); - Py_DECREF(result); - return r; - } - - PyObject* AggregateGradients( - tensorflow::gtl::ArraySlice gradient_tensors) const final { - PyObject* list = PyList_New(gradient_tensors.size()); - for (int i = 0; i < gradient_tensors.size(); ++i) { - // Note: stealing a reference to the gradient tensors. 
- CHECK(gradient_tensors[i] != nullptr); - CHECK(gradient_tensors[i] != Py_None); - PyList_SET_ITEM(list, i, - reinterpret_cast(gradient_tensors[i])); - } - PyObject* arglist = Py_BuildValue("(O)", list); - CHECK(arglist != nullptr); - PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); - Py_DECREF(arglist); - Py_DECREF(list); - return result; - } - - void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } - - PyObject* Zeros(tensorflow::TensorShape shape, - tensorflow::DataType dtype) const final { - PyObject* py_shape = PyTuple_New(shape.dims()); - for (int i = 0; i < shape.dims(); ++i) { - PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); - } - PyObject* py_dtype = PyLong_FromLong(static_cast(dtype)); - PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); - PyObject* result = PyEval_CallObject(zeros_, arg_list); - Py_DECREF(arg_list); - Py_DECREF(py_dtype); - Py_DECREF(py_shape); - return reinterpret_cast(result); - } - - PyObject* Ones(tensorflow::TensorShape shape, - tensorflow::DataType dtype) const final { - PyObject* py_shape = PyTuple_New(shape.dims()); - for (int i = 0; i < shape.dims(); ++i) { - PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); - } - PyObject* py_dtype = PyLong_FromLong(static_cast(dtype)); - PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); - PyObject* result = PyEval_CallObject(ones_, arg_list); - Py_DECREF(arg_list); - Py_DECREF(py_dtype); - Py_DECREF(py_shape); - return result; - } - - tensorflow::Status CallBackwardFunction( - PyBackwardFunction* backward_function, - tensorflow::gtl::ArraySlice output_gradients, - std::vector* result) const final { - PyObject* grads = PyTuple_New(output_gradients.size()); - for (int i = 0; i < output_gradients.size(); ++i) { - if (output_gradients[i] == nullptr) { - Py_INCREF(Py_None); - PyTuple_SET_ITEM(grads, i, Py_None); - } else { - PyTuple_SET_ITEM(grads, i, - reinterpret_cast(output_gradients[i])); - } - } - PyObject* py_result = (*backward_function)(grads); - Py_DECREF(grads); - if (py_result == nullptr) { - return tensorflow::errors::Internal("gradient function threw exceptions"); - } - result->clear(); - PyObject* seq = - PySequence_Fast(py_result, "expected a sequence of gradients"); - if (seq == nullptr) { - return tensorflow::errors::InvalidArgument( - "gradient function did not return a list"); - } - int len = PySequence_Fast_GET_SIZE(seq); - VLOG(1) << "Gradient length is " << len; - result->reserve(len); - for (int i = 0; i < len; ++i) { - PyObject* item = PySequence_Fast_GET_ITEM(seq, i); - if (item == Py_None) { - result->push_back(nullptr); - } else { - Py_INCREF(item); - result->push_back(item); - } - } - Py_DECREF(seq); - Py_DECREF(py_result); - return tensorflow::Status::OK(); - } - - void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } - - private: - PyObject* py_vspace_; - - PyObject* num_elements_; - PyObject* aggregate_fn_; - PyObject* zeros_; - PyObject* ones_; -}; -PyVSpace* py_vspace = nullptr; - -PyObject* TFE_Py_RegisterVSpace(PyObject* e) { - if (py_vspace != nullptr) { - delete py_vspace; - } - - py_vspace = new PyVSpace(e); - auto status = py_vspace->Initialize(); - if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - delete py_vspace; - return nullptr; - } - - Py_RETURN_NONE; -} - std::vector MakeTensorList(PyObject* tensors) { PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); if (seq == nullptr) { -- cgit v1.2.3
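Editor's note on the shape of the change: after this patch, `GradientTape`/`VSpace` no longer assume a concrete `TapeTensor{id, dtype, shape}` struct; the tape tensor is a template parameter that only needs to expose an id (via `GetID()`) and be consumable by `VSpace::Zeros`/`Ones`, which now take the whole tensor instead of a `(TensorShape, DataType)` pair. That is what lets the Python binding store either a concrete `TensorShape` or a deferred handle (resolved through the new `graph_shape_fn`) for graph-mode tensors with unknown shapes.

Below is a minimal standalone sketch of that contract, not the real TensorFlow types; `ToyTapeTensor` and `ToyVSpace` are hypothetical names used purely for illustration.

```cpp
// Sketch only: mirrors the interface the templatized GradientTape expects,
// using toy types rather than tensorflow::eager classes.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

// A tape tensor must expose an id; how it represents its shape is up to the
// client (a concrete shape, or a handle resolved lazily at gradient time).
class ToyTapeTensor {
 public:
  ToyTapeTensor(int64_t id, std::string shape_desc)
      : id_(id), shape_desc_(std::move(shape_desc)) {}
  int64_t GetID() const { return id_; }  // required by the tape
  const std::string& ShapeDesc() const { return shape_desc_; }

 private:
  int64_t id_;
  std::string shape_desc_;  // stand-in for "TensorShape or deferred handle"
};

// The vspace receives the whole tape tensor instead of (shape, dtype), so it
// can decide how to materialize zeros/ones, even for unknown shapes.
struct ToyVSpace {
  std::string Zeros(const ToyTapeTensor& t) const {
    return "zeros_like(" + t.ShapeDesc() + ")";
  }
  std::string Ones(const ToyTapeTensor& t) const {
    return "ones_like(" + t.ShapeDesc() + ")";
  }
};

int main() {
  ToyVSpace vspace;
  ToyTapeTensor known(/*id=*/1, "[2, 3]");
  ToyTapeTensor unknown(/*id=*/2, "<shape resolved at runtime>");
  std::cout << vspace.Ones(known) << "\n";     // concrete shape known upfront
  std::cout << vspace.Zeros(unknown) << "\n";  // shape deferred, e.g. graph mode
  return 0;
}
```

On the Python side, the same idea surfaces as the renamed `zeros_fn`/`ones_fn` plus the added `graph_shape_fn` in the `VSpace` namedtuple, which the new `testUnknownShapes` test exercises with a shapeless placeholder in graph mode.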