aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/local_session.h
blob: 453cfdde479fe9ca29a80f34abad735c47b2da56 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
#define TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

class Device;

// A Session implementation that runs graphs on the devices of a DeviceMgr
// which this object owns (see 'device_mgr_').
//
// Locking, as declared by the annotations below:
//   - 'graph_def_lock_' guards 'graph_def_' and 'graph_created_'.
//   - 'executor_lock_' guards the executor cache 'executors_'.
//   - 'mu_' guards 'stateful_placements_' and 'name_counter_'.
class LocalSession : public Session {
 public:
  // Takes ownership of 'device_mgr'.
  LocalSession(const SessionOptions& options, const DeviceMgr* device_mgr);
  ~LocalSession() override;

  // Session interface; see tensorflow/core/public/session.h for the contract
  // of each of these methods.
  ::tensorflow::Status Create(const GraphDef& graph) override;
  ::tensorflow::Status Extend(const GraphDef& graph) override;
  ::tensorflow::Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
                           std::vector<Tensor>* outputs) override;
  ::tensorflow::Status Close() override;

 private:
  // One cached set of per-device executors for a run signature, together
  // with the string->string key maps that accompany it.
  //
  // Owns every Executor* stored in 'device_executors' and deletes them when
  // destroyed. Copying is disabled: a copy would share those raw pointers and
  // delete each executor twice.
  struct ExecutorsAndKeys {
    // Keyed by a string; NOTE(review): presumably the device/partition name —
    // confirm in local_session.cc.
    std::unordered_map<string, Executor*> device_executors;
    // NOTE(review): presumably map feed/fetch names to the keys used to pass
    // tensors in and out of the executors — confirm in local_session.cc.
    std::unordered_map<string, string> input_keys;
    std::unordered_map<string, string> output_keys;

    // Explicitly defaulted because declaring the (deleted) copy operations
    // below would otherwise suppress the implicit default constructor.
    ExecutorsAndKeys() = default;
    ExecutorsAndKeys(const ExecutorsAndKeys&) = delete;
    ExecutorsAndKeys& operator=(const ExecutorsAndKeys&) = delete;

    ~ExecutorsAndKeys() {
      // Iterate by reference: iterating by value copied each key string.
      for (const auto& entry : device_executors) {
        delete entry.second;
      }
    }
  };

  // Retrieves an already existing set of executors to run 'inputs',
  // 'outputs' and 'target_nodes', or creates and caches them for future use.
  // On success, '*executors_and_keys' points at the selected entry.
  // NOTE(review): the entry appears to live in the 'executors_' cache, so the
  // caller presumably must not delete it — confirm in local_session.cc.
  ::tensorflow::Status GetOrCreateExecutors(
      gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
      gtl::ArraySlice<string> target_nodes,
      ExecutorsAndKeys** executors_and_keys);

  // Creates several graphs from the existing 'graph_def_' for the given
  // 'feeds', 'fetches' and 'target_nodes', storing them in '*outputs' keyed
  // by a string (NOTE(review): presumably the partition/device name).
  ::tensorflow::Status CreateGraphs(
      gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches,
      gtl::ArraySlice<string> target_nodes,
      std::unordered_map<string, Graph*>* outputs);

  // Implementation of Extend(); the caller must already hold
  // 'graph_def_lock_'.
  ::tensorflow::Status ExtendLocked(const GraphDef& graph)
      EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);

  const SessionOptions options_;

  // Device structures.
  const std::unique_ptr<const DeviceMgr> device_mgr_;
  std::vector<Device*> devices_;           // not owned
  DeviceSet device_set_;

  string session_handle_;
  bool graph_created_ GUARDED_BY(graph_def_lock_) = false;

  mutex graph_def_lock_;
  GraphDef graph_def_ GUARDED_BY(graph_def_lock_);

  mutex executor_lock_;  // protects executors_
  // Holds mappings from signature to the executors that process
  // it. The reason for a level of indirection around mapped_type is
  // to guarantee address stability.
  // NOTE(review): the values are raw pointers; their deletion is not visible
  // in this header — presumably done in ~LocalSession.
  std::unordered_map<string, ExecutorsAndKeys*> executors_
      GUARDED_BY(executor_lock_);

  // NOTE(review): raw pointer whose ownership is not visible in this header;
  // presumably created and destroyed by LocalSession itself — confirm in
  // local_session.cc.
  CancellationManager* cancellation_manager_;

  // Saves and restores device placements for stateful nodes.
  // 'mu_' guards 'stateful_placements_' and 'name_counter_'.
  mutex mu_;
  void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
  void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes.  Once placed, these
  // nodes cannot be moved to a different device.  Maps node names to
  // device names.
  std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);

  // For generating unique names.
  int64 name_counter_ GUARDED_BY(mu_) = 0;

  TF_DISALLOW_COPY_AND_ASSIGN(LocalSession);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_