diff options
Diffstat (limited to 'tensorflow/compiler/jit/xla_tensor.h')
-rw-r--r-- | tensorflow/compiler/jit/xla_tensor.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index c54001a999..f7e401c731 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -85,6 +85,24 @@ class XlaTensor { host_tensor_.reset(new Tensor(tensor)); } + // If the tensor's content is not yet defined on 'stream', and there exists an + // se::Event declaring when the tensor's content is defined, return it. + // Otherwise, return nullptr. If this function returns nullptr then the + // tensor's content can be read on 'stream' without additional + // synchronization. + se::Event* GetDefinitionEvent(se::Stream* stream); + + // Assert that the tensor's content is defined on 'stream' by the time 'event' + // triggers. + void SetDefinedOn(se::Stream* stream, se::Event event); + + // Assert that the tensor's content is defined on 'stream'. This version does + // not provide an event, and must be called *after* SetDefinedOn(Stream, + // Event). This call can be read as an assertion that the definition event has + // been waited on by 'stream', so further calls to GetDefinitionEvent(stream) + // do not need to also wait on the event. + void SetDefinedOn(se::Stream* stream); + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); // Convert to a raw pointer from an XlaTensor, adding the pointer tag. @@ -95,6 +113,14 @@ class XlaTensor { std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_; // An optional host tensor value. std::unique_ptr<Tensor> host_tensor_; + // An optional event that is triggered when the tensor's content has been + // defined. If this event is nullptr, it is assumed that the tensor's content + // is always defined. + gtl::optional<se::Event> definition_event_; + // A list of all streams for which the tensor's content is defined for any + // newly enqueued command. + gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_); + mutex mu_; }; } // namespace tensorflow |