diff options
-rw-r--r-- | tensorflow/python/BUILD | 13 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/lib/core/ndarray_tensor.cc | 9 | ||||
-rw-r--r-- | tensorflow/python/lib/core/ndarray_tensor.h | 16 | ||||
-rw-r--r-- | tensorflow/python/lib/core/safe_ptr.cc | 33 | ||||
-rw-r--r-- | tensorflow/python/lib/core/safe_ptr.h | 41 | ||||
-rw-r--r-- | tensorflow/python/util/py_checkpoint_reader.i | 1 |
7 files changed, 91 insertions, 23 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c1e63c0d85..0d8253bcdc 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -247,12 +247,23 @@ cc_library( ) cc_library( + name = "safe_ptr", + srcs = ["lib/core/safe_ptr.cc"], + hdrs = ["lib/core/safe_ptr.h"], + deps = [ + "//tensorflow/c:c_api", + "//util/python:python_headers", + ], +) + +cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], hdrs = ["lib/core/ndarray_tensor.h"], deps = [ ":ndarray_tensor_bridge", ":numpy_lib", + ":safe_ptr", "//tensorflow/c:c_api", "//tensorflow/c:tf_status_helper", "//tensorflow/core:framework", @@ -2860,6 +2871,7 @@ tf_cuda_library( ":ndarray_tensor", ":ndarray_tensor_bridge", ":numpy_lib", + ":safe_ptr", ":test_ops_kernels", "//tensorflow/c:c_api", "//tensorflow/c:tf_status_helper", @@ -2917,6 +2929,7 @@ tf_py_wrap_cc( ":cpp_shape_inference", ":kernel_registry", ":numpy_lib", + ":safe_ptr", ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 72f560fa87..619edbe193 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" +#include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index f10595648b..b1a5a37924 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" namespace tensorflow { - namespace { Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, @@ -434,12 +433,4 @@ Status TensorToNdarray(const Tensor& t, PyObject** ret) { return TF_TensorToPyArray(std::move(tf_tensor), ret); } -Safe_PyObjectPtr make_safe(PyObject* o) { - return Safe_PyObjectPtr(o, Py_DECREF_wrapper); -} - -Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) { - return Safe_TF_TensorPtr(tensor, TF_DeleteTensor); -} - } // namespace tensorflow diff --git a/tensorflow/python/lib/core/ndarray_tensor.h b/tensorflow/python/lib/core/ndarray_tensor.h index 57b4ffa7f0..5172d504bd 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.h +++ b/tensorflow/python/lib/core/ndarray_tensor.h @@ -22,22 +22,9 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { -// Safe container for an owned PyObject. On destruction, the reference count of -// the contained object will be decremented. -inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); } -// Note: can't use decltype(&Py_DECREF_wrapper) due to SWIG -typedef void (*Py_DECREF_wrapper_type)(PyObject*); -typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr; -Safe_PyObjectPtr make_safe(PyObject* o); - -// Safe containers for an owned TF_Tensor. On destruction, the tensor will be -// deleted by TF_DeleteTensor. -// Note: can't use decltype(&TF_DeleteTensor) due to SWIG -typedef void (*TF_DeleteTensor_type)(TF_Tensor*); -typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr; -Safe_TF_TensorPtr make_safe(TF_Tensor* tensor); Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray); @@ -55,6 +42,7 @@ Status NdarrayToTensor(PyObject* obj, Tensor* ret); // Creates a numpy array in 'ret' which either aliases the content of 't' or has // a copy. Status TensorToNdarray(const Tensor& t, PyObject** ret); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_ diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/python/lib/core/safe_ptr.cc new file mode 100644 index 0000000000..37d0083848 --- /dev/null +++ b/tensorflow/python/lib/core/safe_ptr.cc @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/python/lib/core/safe_ptr.h" + +namespace tensorflow { +namespace { + +inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); } + +} // namespace + +Safe_PyObjectPtr make_safe(PyObject* o) { + return Safe_PyObjectPtr(o, Py_DECREF_wrapper); +} + +Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) { + return Safe_TF_TensorPtr(tensor, TF_DeleteTensor); +} + +} // namespace tensorflow diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h new file mode 100644 index 0000000000..b01f614977 --- /dev/null +++ b/tensorflow/python/lib/core/safe_ptr.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ +#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ + +#include <memory> +#include <Python.h> + +#include "tensorflow/c/c_api.h" + +namespace tensorflow { + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +typedef void (*Py_DECREF_wrapper_type)(PyObject*); +typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr; +Safe_PyObjectPtr make_safe(PyObject* o); + +// Safe containers for an owned TF_Tensor. On destruction, the tensor will be +// deleted by TF_DeleteTensor. +// Note: can't use decltype(&TF_DeleteTensor) due to SWIG +typedef void (*TF_DeleteTensor_type)(TF_Tensor*); +typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr; +Safe_TF_TensorPtr make_safe(TF_Tensor* tensor); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_ diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i index 6f761a4bff..1d20f9756f 100644 --- a/tensorflow/python/util/py_checkpoint_reader.i +++ b/tensorflow/python/util/py_checkpoint_reader.i @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/py_func.h" +#include "tensorflow/python/lib/core/safe_ptr.h" %} %typemap(out) const tensorflow::checkpoint::TensorSliceReader::VarToShapeMap& { |