aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_tensor.cc')
-rw-r--r-- tensorflow/compiler/jit/xla_tensor.cc 30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 3c44c4ae6d..5dff187fff 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -73,6 +73,36 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
return Status::OK();
}
+// Returns the definition event to wait on for `stream`, or nullptr when no
+// wait is needed: either no definition event has been set, or `stream` is
+// already recorded in streams_defined_on_. A non-null result points at the
+// event held by this XlaTensor.
+se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
+  mutex_lock lock(mu_);
+  if (definition_event_.has_value()) {
+    // Very few streams (usually 1-2) are ever recorded here, so a plain linear
+    // search is cheap enough.
+    const auto known_end = streams_defined_on_.end();
+    const bool defined_on_stream =
+        std::find(streams_defined_on_.begin(), known_end, stream) != known_end;
+    if (!defined_on_stream) {
+      return &*definition_event_;
+    }
+  }
+
+  // Either there is no definition event at all, or `stream` is one of the
+  // streams in streams_defined_on_ and therefore needs no wait.
+  return nullptr;
+}
+
+// Stores `event` as this tensor's definition event and records `stream` in
+// streams_defined_on_, so that GetDefinitionEvent(stream) returns nullptr for
+// it. Setting the definition event is a one-shot operation: the CHECK below
+// fails if an event has already been stored.
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+  mutex_lock lock(mu_);
+
+  // Guard against a second definition event being installed.
+  CHECK(!definition_event_.has_value())
+      << "SetDefinedOn must only be called once!";
+
+  // `event` was taken by value, so it can be moved into the member.
+  definition_event_ = std::move(event);
+  streams_defined_on_.push_back(stream);
+}
+
+// Records `stream` in streams_defined_on_ without touching the definition
+// event; GetDefinitionEvent(stream) will subsequently return nullptr for it.
+// Unlike the two-argument overload, this one performs no once-only check.
+void XlaTensor::SetDefinedOn(se::Stream* stream) {
+  mutex_lock lock(mu_);  // Protects streams_defined_on_.
+  streams_defined_on_.push_back(stream);
+}
+
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.