aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/client/client_session.h
blob: 7dd653eec4ec729b652cb779d06e820bfb437b3c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

/// @addtogroup core
/// @{

/// A `ClientSession` object lets the caller drive the evaluation of the
/// TensorFlow graph constructed with the C++ API.
///
/// Example:
///
///     Scope root = Scope::NewRootScope();
///     auto a = Placeholder(root, DT_INT32);
///     auto c = Add(root, a, {41});
///
///     ClientSession session(root);
///     std::vector<Tensor> outputs;
///
///     Status s = session.Run({ {a, {1}} }, {c}, &outputs);
///     if (!s.ok()) { ... }
///
/// `ClientSession` is neither copyable nor movable: it owns its implementation
/// through a `std::unique_ptr` (which suppresses copying) and declares its own
/// destructor (which suppresses the implicit move operations).
class ClientSession {
 public:
  /// A data type to represent feeds to a Run call.
  ///
  /// This is a map of `Output` objects returned by op-constructors to the value
  /// to feed them with. See `Input::Initializer` for details on what can be
  /// used as feed values.
  using FeedType = std::unordered_map<Output, Input::Initializer, OutputHash>;

  /// Create a new session to evaluate the graph contained in `scope` by
  /// connecting to the TensorFlow runtime specified by `target`.
  ClientSession(const Scope& scope, const string& target);

  /// Same as above, but use the empty string ("") as the target specification.
  ClientSession(const Scope& scope);

  /// Create a new session, configuring it with `session_options`.
  ClientSession(const Scope& scope, const SessionOptions& session_options);

  /// Declared here but defined out of line: `Impl` is only forward-declared in
  /// this header, and destroying a `std::unique_ptr<Impl>` requires the
  /// complete type.
  ~ClientSession();

  /// Evaluate the tensors in `fetch_outputs`. The values are returned as
  /// `Tensor` objects in `outputs`. The number and order of `outputs` will
  /// match `fetch_outputs`.
  Status Run(const std::vector<Output>& fetch_outputs,
             std::vector<Tensor>* outputs) const;

  /// Same as above, but use the mapping in `inputs` as feeds.
  Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
             std::vector<Tensor>* outputs) const;

  /// Same as above. Additionally runs the operations in `run_outputs`.
  Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
             const std::vector<Operation>& run_outputs,
             std::vector<Tensor>* outputs) const;

  /// Same as above, but with per-call `run_options` (e.g. to turn on
  /// performance profiling). `run_metadata`, if not null, is filled in with
  /// the profiling results.
  Status Run(const RunOptions& run_options, const FeedType& inputs,
             const std::vector<Output>& fetch_outputs,
             const std::vector<Operation>& run_outputs,
             std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;

  /// \brief A handle to a subgraph, created with
  /// `ClientSession::MakeCallable()`.
  using CallableHandle = int64;

  /// \brief Creates a handle for invoking the subgraph defined by
  /// `callable_options`, and stores it in `*out_handle`.
  ///
  /// The returned handle is passed to `RunCallable()` and should eventually be
  /// released with `ReleaseCallable()`.
  /// NOTE: This API is still experimental and may change.
  Status MakeCallable(const CallableOptions& callable_options,
                      CallableHandle* out_handle);

  /// \brief Invokes the subgraph named by `handle` with the `CallableOptions`
  /// it was created with, feeding it `feed_tensors`.
  ///
  /// The order of tensors in `feed_tensors` must match the order of names in
  /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
  /// match the order of names in `CallableOptions::fetch()` when this subgraph
  /// was created.
  ///
  /// `fetch_tensors` receives the fetched values. `run_metadata` receives
  /// run metadata (presumably it may be null, as in `Run()`; confirm against
  /// the implementation).
  /// NOTE: This API is still experimental and may change.
  Status RunCallable(CallableHandle handle,
                     const std::vector<Tensor>& feed_tensors,
                     std::vector<Tensor>* fetch_tensors,
                     RunMetadata* run_metadata);

  /// \brief Releases resources associated with the given `handle` in this
  /// session.
  /// NOTE: This API is still experimental and may change.
  Status ReleaseCallable(CallableHandle handle);

 private:
  // Pointer-to-implementation: all session state lives behind the
  // forward-declared `Impl`, keeping runtime headers out of this public API.
  class Impl;
  std::unique_ptr<Impl> impl_;
  Impl* impl() { return impl_.get(); }
  const Impl* impl() const { return impl_.get(); }
};

/// @}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_