aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_tensor.h')
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index c54001a999..f7e401c731 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -85,6 +85,24 @@ class XlaTensor {
host_tensor_.reset(new Tensor(tensor));
}
+ // If the tensor's content is not yet defined on 'stream', and there exists an
+ // se::Event declaring when the tensor's content is defined, return it.
+ // Otherwise, return nullptr. If this function returns nullptr then the
+ // tensor's content can be read on 'stream' without additional
+ // synchronization.
+ se::Event* GetDefinitionEvent(se::Stream* stream);
+
+ // Assert that the tensor's content is defined on 'stream' by the time 'event'
+ // triggers.
+ void SetDefinedOn(se::Stream* stream, se::Event event);
+
+ // Assert that the tensor's content is defined on 'stream'. This version does
+ // not provide an event, and must be called *after* SetDefinedOn(Stream,
+ // Event). This call can be read as an assertion that the definition event has
+ // been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
+ // do not need to also wait on the event.
+ void SetDefinedOn(se::Stream* stream);
+
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,14 @@ class XlaTensor {
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
+ // An optional event that is triggered when the tensor's content has been
+ // defined. If this event is nullptr, it is assumed that the tensor's content
+ // is always defined.
+ gtl::optional<se::Event> definition_event_;
+ // A list of all streams for which the tensor's content is defined for any
+ // newly enqueued command.
+ gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+ mutex mu_;
};
} // namespace tensorflow