aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/plugin/executor/executor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/plugin/executor/executor.h')
-rw-r--r--  tensorflow/compiler/plugin/executor/executor.h  213
1 file changed, 213 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/plugin/executor/executor.h
new file mode 100644
index 0000000000..32fdb157e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.h
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include <list>
+#include <mutex>
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+// Argument list handed to ExecuteGraph: a non-owning view of device buffers.
+using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+
+// CPU-only StreamExecutor implementation backing the "executor" XLA plugin.
+// Most StreamExecutor operations are stubbed out (returning false or an
+// UNIMPLEMENTED status); streams and timers are provided by the host-based
+// HostStream/HostTimer classes, and graph execution is funneled through
+// ExecuteGraph below. Declarations without inline bodies are defined in the
+// accompanying executor.cc.
+class ExecutorExecutor : public internal::StreamExecutorInterface {
+ public:
+  explicit ExecutorExecutor(const PluginConfig &plugin_config);
+  ~ExecutorExecutor() override;
+
+  // Device initialization is a no-op for this host-backed executor; it
+  // always succeeds regardless of ordinal or options.
+  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
+    return port::Status::OK();
+  }
+
+  // Kernel loading and launching are not supported on this platform; both
+  // operations report failure.
+  bool GetKernel(const MultiKernelLoaderSpec &spec,
+                 KernelBase *kernel) override {
+    return false;
+  }
+  bool Launch(Stream *stream, const ThreadDim &thread_dims,
+              const BlockDim &block_dims, const KernelBase &kernel,
+              const KernelArgsArrayBase &args) override {
+    return false;
+  }
+
+  // "Device" memory management (defined in executor.cc).
+  void *Allocate(uint64 size) override;
+  void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+                          uint64 size_bytes) override;
+  void Deallocate(DeviceMemoryBase *mem) override;
+
+  // Host memory is ordinary heap memory; registration/unregistration are
+  // no-ops that report success.
+  void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
+  void HostMemoryDeallocate(void *mem) override {
+    delete[] static_cast<char *>(mem);
+  }
+  bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
+  bool HostMemoryUnregister(void *mem) override { return true; }
+
+  // Stream-ordered host<->device copies (defined in executor.cc).
+  // Device-to-device copies are unsupported and report failure.
+  bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
+              uint64 size) override;
+  bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
+              uint64 size) override;
+  bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
+                            const DeviceMemoryBase &host_src,
+                            uint64 size) override {
+    return false;
+  }
+
+  // Stream-ordered memset variants are unsupported; all report failure.
+  bool MemZero(Stream *stream, DeviceMemoryBase *location,
+               uint64 size) override {
+    return false;
+  }
+  bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+              uint64 size) override {
+    return false;
+  }
+  bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+                uint64 size) override {
+    return false;
+  }
+
+  // No "synchronize all activity" implemented for this platform at the moment.
+  bool SynchronizeAllActivity() override { return false; }
+  bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
+    return false;
+  }
+
+  bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+                         uint64 size) override {
+    return false;
+  }
+
+  // Blocking host<->device copies (defined in executor.cc). Device-to-device
+  // remains unimplemented.
+  port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
+                                 const void *host_src, uint64 size) override;
+  port::Status SynchronousMemcpy(void *host_dst,
+                                 const DeviceMemoryBase &pop_src,
+                                 uint64 size) override;
+  port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
+                                               const DeviceMemoryBase &pop_src,
+                                               uint64 size) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+  // Event support is not implemented: allocation/recording/waiting all fail,
+  // and polling always reports an error status.
+  port::Status AllocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status DeallocateEvent(Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status RecordEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  port::Status WaitForEvent(Stream *stream, Event *event) override {
+    return port::Status{port::error::UNIMPLEMENTED, ""};
+  }
+
+  Event::Status PollForEventStatus(Event *event) override {
+    return Event::Status::kError;
+  }
+
+  // Streams and timers need no per-object device state here, so allocation
+  // trivially succeeds and deallocation is a no-op; the real work happens in
+  // the HostStream/HostTimer implementations returned below.
+  bool AllocateStream(Stream *stream) override { return true; }
+  void DeallocateStream(Stream *stream) override {}
+  bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+  bool AllocateTimer(Timer *timer) override { return true; }
+  void DeallocateTimer(Timer *timer) override {}
+  bool StartTimer(Stream *stream, Timer *timer) override;
+  bool StopTimer(Stream *stream, Timer *timer) override;
+
+  bool BlockHostUntilDone(Stream *stream) override;
+
+  // This plugin exposes exactly one (host) device.
+  int PlatformDeviceCount() override { return 1; }
+
+  // Memory-usage reporting is unsupported; *free/*total are left untouched.
+  bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
+    return false;
+  }
+
+  DeviceDescription *PopulateDeviceDescription() const override;
+
+  // Peer access is trivially "enabled" — everything lives in host memory.
+  port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return port::Status::OK();
+  }
+
+  bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
+    return true;
+  }
+
+  SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+    return SharedMemoryConfig::kDefault;
+  }
+
+  port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
+    return port::Status{port::error::UNIMPLEMENTED,
+                        "Shared memory not supported"};
+  }
+
+  // Events are unsupported, so no event implementation object is provided.
+  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+      override {
+    return nullptr;
+  }
+
+  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+      override {
+    return nullptr;
+  }
+
+  // Streams and timers delegate to the host-based implementations.
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+      override {
+    return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+  }
+
+  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+    return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
+  }
+
+  // Runs a compiled computation over `args`, returning the output buffer for
+  // the given result `shape`. NOTE(review): only the declaration is visible
+  // here — presumably defined in executor.cc; confirm semantics there.
+  port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+                                                Args args);
+
+ private:
+  // Helpers for allocating buffers that hold graph outputs (defined in
+  // executor.cc).
+  DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
+
+  port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+      const xla::Shape &shape);
+
+  const PluginConfig plugin_config_;
+};
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_