aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/threadpool_device.h
blob: 5b0347231f526837ca726084f1ce9c711270e99e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
#define TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"

namespace tensorflow {

// CPU device implementation.
// CPU device implementation, layered on top of LocalDevice.
//
// Ownership: the device keeps a raw, non-owning Allocator pointer
// (`allocator_`). NOTE(review): presumably this is the `allocator` passed to
// the constructor, in which case that allocator must outlive the device —
// confirm in threadpool_device.cc.
class ThreadPoolDevice : public LocalDevice {
 public:
  // Constructs the device.
  //   options       - session configuration.
  //   name          - name of this device.
  //   memory_limit  - memory limit for the device.
  //   bus_adjacency - BusAdjacency value for the device.
  //   allocator     - allocator for the device; NOT owned by the device.
  ThreadPoolDevice(const SessionOptions& options,
                   const string& name,
                   Bytes memory_limit,
                   BusAdjacency bus_adjacency,
                   Allocator* allocator);
  ~ThreadPoolDevice() override;

  // ---- Overrides of virtual methods inherited via LocalDevice ----

  // Receives the kernel to run together with its OpKernelContext.
  // Defined out of line.
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;

  // Returns the allocator to use for the given attributes.
  // NOTE(review): likely returns `allocator_` for every `attr` — confirm.
  Allocator* GetAllocator(AllocatorAttributes attr) override;

  // Builds `*tensor` from `tensor_proto`, using `alloc_attrs` to select the
  // allocation; the outcome is reported through the returned Status.
  // (Top-level const on the by-value `alloc_attrs` is omitted here: it is not
  // part of the function type, so this still overrides the base method.)
  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override;

  // Nothing to wait for: reports success immediately.
  Status Sync() override {
    return Status::OK();
  }

 private:
  Allocator* allocator_;  // Borrowed; this class never frees it.
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_