aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/executor_cache.h
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-31 12:58:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-31 13:02:31 -0700
commit122750a879f508342b53ee2eaecfde1656981280 (patch)
tree71890cf506fa2d80c0b6a521e7249d7167403731 /tensorflow/stream_executor/executor_cache.h
parent7ebed6678c2f378a5e28c8bade99866718dbdc7e (diff)
[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.
[XLA] Create Executors in parallel. PiperOrigin-RevId: 163734988
Diffstat (limited to 'tensorflow/stream_executor/executor_cache.h')
-rw-r--r--tensorflow/stream_executor/executor_cache.h40
1 files changed, 31 insertions, 9 deletions
diff --git a/tensorflow/stream_executor/executor_cache.h b/tensorflow/stream_executor/executor_cache.h
index 97bfaaecc9..12f2275f6d 100644
--- a/tensorflow/stream_executor/executor_cache.h
+++ b/tensorflow/stream_executor/executor_cache.h
@@ -16,40 +16,62 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
+#include <functional>
+#include <map>
+
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
namespace perftools {
namespace gputools {
// Utility class to allow Platform objects to manage cached StreamExecutors.
+// Thread-safe.
class ExecutorCache {
public:
ExecutorCache() {}
- // Inserts a new StreamExecutor with the given configuration into the cache.
- // Will not overwrite if called when a matching element is already present.
- port::Status Insert(const StreamExecutorConfig& config,
- std::unique_ptr<StreamExecutor> executor);
+ // Looks up 'config' in the cache. Returns a pointer to the existing executor,
+ // if already present, or creates it using 'factory', if it does not.
+ // Factories may be executed concurrently for different device ordinals.
+ typedef port::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
+ port::StatusOr<StreamExecutor*> GetOrCreate(
+ const StreamExecutorConfig& config,
+ const std::function<ExecutorFactory>& factory);
// Returns a pointer to the described executor (if one with a matching config
// has been created), or a NOT_FOUND status.
port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
// Destroys all Executors and clears the cache.
- // Performs no synchronization - undefined behavior may occur if any executors
- // are active!
+ // Performs no synchronization with the executors - undefined behavior may
+ // occur if any executors are active!
void DestroyAllExecutors();
private:
- typedef std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>
- Entry;
+ // Each Entry contains zero or more cached executors for a device ordinal.
+ struct Entry {
+ ~Entry();
+
+ // Mutex that locks the contents of each entry. The 'mutex_' of the
+ // ExecutorCache class protects both the 'cache_' and the existence of each
+ // Entry, but not the Entry's contents. 'configurations_mutex' protects the
+ // contents of the entry after 'mutex_' has been dropped.
+ mutex configurations_mutex;
+
+ // Vector of cached {config, executor} pairs.
+ std::vector<
+ std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
+ configurations GUARDED_BY(configurations_mutex);
+ };
// Maps ordinal number to a list of cached executors for that ordinal.
// We key off of ordinal (instead of just looking up all fields in the
// StreamExecutorConfig) for a slight improvement in lookup time.
- std::map<int, std::vector<Entry>> cache_;
+ mutex mutex_;
+ std::map<int, Entry> cache_ GUARDED_BY(mutex_);
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
};