aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-06-02 05:54:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-02 07:04:04 -0700
commit6db2dc0bdb6f7322f45399d936aa45ba8e8c4d82 (patch)
treea74d4d63221c3d2fcb1bb9bd61b4f1770d8e5daa
parent4d13fa7ef3ef3fd07a3099283437816ac0daeafe (diff)
Removing initialization_done in copy_tensor.cc.
Change: 123860431
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc9
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.h15
2 files changed, 11 insertions, 13 deletions
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index 00f3f17d78..3b2f48fbb5 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include <atomic>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -23,8 +24,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-static bool initialization_done = false;
-
struct RegistrationInfo {
RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
: sender_device_type(s), receiver_device_type(r), copy_function(cf) {}
@@ -51,7 +50,6 @@ void CopyTensor::ViaDMA(const string& edge_name,
const AllocatorAttributes dst_alloc_attr,
const Tensor* input, Tensor* output,
StatusCallback done) {
- initialization_done = true;
port::Tracing::ScopedAnnotation annotation(edge_name);
VLOG(1) << "Copy " << edge_name;
@@ -110,11 +108,6 @@ void CopyTensor::ViaDMA(const string& edge_name,
Status CopyTensor::Register(DeviceType sender_device_type,
DeviceType receiver_device_type,
CopyFunction copy_function) {
- if (initialization_done) {
- return errors::FailedPrecondition(
- "May only register CopyTensor functions during before the first tensor "
- "is copied.");
- }
std::vector<RegistrationInfo>* registry = MutableRegistry();
registry->emplace_back(sender_device_type, receiver_device_type,
copy_function);
diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h
index 9e2003c9b1..140f32019e 100644
--- a/tensorflow/core/common_runtime/copy_tensor.h
+++ b/tensorflow/core/common_runtime/copy_tensor.h
@@ -48,12 +48,9 @@ class CopyTensor {
const AllocatorAttributes dst_alloc_attr,
const Tensor* input, Tensor* output, StatusCallback done);
- // Register a function for copying between two specific DeviceTypes.
- static Status Register(DeviceType sender_device_type,
- DeviceType receiver_device_type,
- CopyFunction copy_function);
-
// Object used to call Register() at static-initialization time.
+ // Note: This should only ever be used as a global-static object; no stack
+ // or heap instances.
class Registration {
public:
Registration(DeviceType sender_device_type, DeviceType receiver_device_type,
@@ -62,6 +59,14 @@ class CopyTensor {
Register(sender_device_type, receiver_device_type, copy_function));
}
};
+
+ private:
+ // Register a function for copying between two specific DeviceTypes.
+ // Note: This should only be called via the constructor of
+ // CopyTensor::Registration.
+ static Status Register(DeviceType sender_device_type,
+ DeviceType receiver_device_type,
+ CopyFunction copy_function);
};
} // namespace tensorflow