diff options
author | 2016-06-02 05:54:12 -0800 | |
---|---|---|
committer | 2016-06-02 07:04:04 -0700 | |
commit | 6db2dc0bdb6f7322f45399d936aa45ba8e8c4d82 (patch) | |
tree | a74d4d63221c3d2fcb1bb9bd61b4f1770d8e5daa | |
parent | 4d13fa7ef3ef3fd07a3099283437816ac0daeafe (diff) |
Removing initialization_done in copy_tensor.cc.
Change: 123860431
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/copy_tensor.h | 15 |
2 files changed, 11 insertions, 13 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 00f3f17d78..3b2f48fbb5 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/copy_tensor.h" +#include <atomic> #include <vector> #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -23,8 +24,6 @@ limitations under the License. namespace tensorflow { namespace { -static bool initialization_done = false; - struct RegistrationInfo { RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf) : sender_device_type(s), receiver_device_type(r), copy_function(cf) {} @@ -51,7 +50,6 @@ void CopyTensor::ViaDMA(const string& edge_name, const AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, StatusCallback done) { - initialization_done = true; port::Tracing::ScopedAnnotation annotation(edge_name); VLOG(1) << "Copy " << edge_name; @@ -110,11 +108,6 @@ void CopyTensor::ViaDMA(const string& edge_name, Status CopyTensor::Register(DeviceType sender_device_type, DeviceType receiver_device_type, CopyFunction copy_function) { - if (initialization_done) { - return errors::FailedPrecondition( - "May only register CopyTensor functions during before the first tensor " - "is copied."); - } std::vector<RegistrationInfo>* registry = MutableRegistry(); registry->emplace_back(sender_device_type, receiver_device_type, copy_function); diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h index 9e2003c9b1..140f32019e 100644 --- a/tensorflow/core/common_runtime/copy_tensor.h +++ b/tensorflow/core/common_runtime/copy_tensor.h @@ -48,12 +48,9 @@ class CopyTensor { const AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, StatusCallback done); - // Register a function for copying between two specific DeviceTypes. - static Status Register(DeviceType sender_device_type, - DeviceType receiver_device_type, - CopyFunction copy_function); - // Object used to call Register() at static-initialization time. + // Note: This should only ever be used as a global-static object; no stack + // or heap instances. class Registration { public: Registration(DeviceType sender_device_type, DeviceType receiver_device_type, @@ -62,6 +59,14 @@ class CopyTensor { Register(sender_device_type, receiver_device_type, copy_function)); } }; + + private: + // Register a function for copying between two specific DeviceTypes. + // Note: This should only be called via the constructor of + // CopyTensor::Registration. + static Status Register(DeviceType sender_device_type, + DeviceType receiver_device_type, + CopyFunction copy_function); }; } // namespace tensorflow |