diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_tensor.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_tensor.cc | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 3c44c4ae6d..5dff187fff 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -73,6 +73,36 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, return Status::OK(); } +se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { + mutex_lock lock(mu_); + if (!definition_event_.has_value()) { + return nullptr; + } + + // The set of defined streams is expected to be very small indeed (usually + // 1-2), so a simple linear scan should be fast enough. + if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(), + stream) != streams_defined_on_.end()) { + // stream is in streams_defined_on_; it doesn't need to be waited on. + return nullptr; + } + + return &*definition_event_; +} + +void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { + mutex_lock lock(mu_); + CHECK(!definition_event_.has_value()) + << "SetDefinedOn must only be called once!"; + definition_event_ = std::move(event); + streams_defined_on_.push_back(stream); +} + +void XlaTensor::SetDefinedOn(se::Stream* stream) { + mutex_lock lock(mu_); + streams_defined_on_.push_back(stream); +} + // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. |