aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
blob: ec769d41f96aa956ac7fedc8929e707a89d2e78d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Fixture for the RemoteFusedGraphExecute kernel tests below. It adds no
// members to OpsTestBase; it only gives the TEST_F cases a shared fixture.
class RemoteFusedGraphExecuteTest : public OpsTestBase {
};

// A RemoteFusedGraphExecute node whose two fed inputs and declared Tinputs are
// both DT_FLOAT (with an empty serialized execute info) must finalize and the
// kernel must initialize.
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) {
  const DataTypeVector input_types{DT_FLOAT, DT_FLOAT};
  const DataTypeVector output_types{DT_FLOAT};
  NodeDefBuilder builder("remote_fused_graph_execute_op",
                         "RemoteFusedGraphExecute");
  builder.Input(FakeInput(2, DT_FLOAT))
      .Attr("Tinputs", input_types)
      .Attr("Toutputs", output_types)
      .Attr("serialized_remote_fused_graph_execute_info", "");
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  // TODO(satok): Add benchmark
}

// Declaring Tinputs as DT_INT32 while feeding two DT_FLOAT inputs must make
// NodeDefBuilder::Finalize report an error.
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) {
  const DataTypeVector mismatched_input_types{DT_INT32, DT_INT32};
  const DataTypeVector output_types{DT_FLOAT};
  NodeDefBuilder builder("remote_fused_graph_execute_op",
                         "RemoteFusedGraphExecute");
  builder.Input(FakeInput(2, DT_FLOAT))
      .Attr("Tinputs", mismatched_input_types)
      .Attr("Toutputs", output_types)
      .Attr("serialized_remote_fused_graph_execute_info", "");
  const Status finalize_status = builder.Finalize(node_def());
  ASSERT_FALSE(finalize_status.ok());
  // TODO(satok): Add benchmark
}

////////////////////////////
// End-to-end test: Begin //
////////////////////////////
// This test exercises a simple usage of RemoteFusedGraphExecuteOp end to
// end.

constexpr const char* const NAME_A = "a";
constexpr const char* const NAME_B = "b";
constexpr const char* const NAME_A_PLUS_B = "a_plus_b";
constexpr const char* const REMOTE_FUSED_EXECUTE_OP_NODE_NAME =
    "remote_fused_execute_op";
constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME =
    "build_test_remote_fused_graph_executor";

constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_A_VAL2 = 10.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float FLOAT_VALUE_TOLERANCE = 1e-8f;

// Utility functions //
// Adds a Placeholder node named after `name` (uniquified by the scope) with
// the given dtype and shape to the graph owned by `root`. CHECK-fails if the
// node cannot be built. Returns output 0 of the new node.
static Output BuildPlaceHolderOp(const string& name, const DataType dt,
                                 const TensorShape& tensor_shape, Scope* root) {
  const Scope& named_scope = root->WithOpName(name);
  NodeBuilder placeholder_builder(
      named_scope.GetUniqueNameForOp("Placeholder"), "Placeholder");
  placeholder_builder.Attr("dtype", dt).Attr("shape", tensor_shape);
  named_scope.UpdateBuilder(&placeholder_builder);

  Node* placeholder_node = nullptr;
  named_scope.UpdateStatus(
      placeholder_builder.Finalize(named_scope.graph(), &placeholder_node));
  CHECK(named_scope.ok());
  return Output(placeholder_node, 0);
}

// Adds a RemoteFusedGraphExecute node named after `name` to the graph owned by
// `root`.
//  - Every entry of `output_list` becomes an input of the node. Tinputs is
//    derived from those entries' dtypes, so the attr always matches the
//    number and types of inputs actually wired in.
//  - The node declares `output_node_count` DT_FLOAT outputs (must be > 0).
//  - `execute_info` is serialized into the
//    serialized_remote_fused_graph_execute_info attr.
// CHECK-fails if the node cannot be built. Returns output 0 of the new node.
static Output BuildRemoteFusedGraphExecuteOp(
    const string& name, const std::vector<Output>& output_list,
    const int output_node_count,
    const RemoteFusedGraphExecuteInfo& execute_info, Scope* root) {
  CHECK_GT(output_node_count, 0);
  const Scope& scope = root->WithOpName(name);
  CHECK(scope.ok());
  auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
  const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");

  // Previously Tinputs was hard-coded to {DT_FLOAT}, which only matched a
  // single float input and made any other input list produce an invalid node.
  DataTypeVector input_types;
  for (const Output& input : output_list) {
    input_types.push_back(input.node()->output_type(input.index()));
  }
  // Previously Toutputs ignored `output_node_count`.
  const DataTypeVector output_types(output_node_count, DT_FLOAT);
  // Kept in a named local so the StringPiece below views live storage.
  const string serialized_info = execute_info.SerializeAsString();

  auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
                     .Input(node_out_list)
                     .Attr("Tinputs", input_types)
                     .Attr("Toutputs", output_types)
                     .Attr("serialized_remote_fused_graph_execute_info",
                           StringPiece(serialized_info));
  scope.UpdateBuilder(&builder);
  Node* ret = nullptr;
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  CHECK(scope.ok());
  return Output(ret, 0);
}

// Packs `original_graph` into a RemoteFusedGraphExecuteInfo for the sample
// executor registered under REMOTE_FUSED_EXECUTOR_NAME.
//  - Every node is copied into remote_graph and annotated with a single
//    scalar DT_FLOAT output shape/type.
//  - NAME_A is declared as the only graph input and NAME_A_PLUS_B as the only
//    graph output, both as rank-0 DT_FLOAT tensors.
static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
    const GraphDef& original_graph) {
  RemoteFusedGraphExecuteInfo info;
  info.set_executor_name(REMOTE_FUSED_EXECUTOR_NAME);

  // Every node is copied verbatim here; a real rewrite would normally leave
  // out nodes that inference does not need.
  auto* remote_graph = info.mutable_remote_graph();
  for (int i = 0; i < original_graph.node_size(); ++i) {
    NodeDef* remote_node = remote_graph->add_node();
    *remote_node = original_graph.node(i);
    // TODO(satok): Use TensorShapeMap to determine the tensor shape type
    // instead of assuming one scalar DT_FLOAT output per node.
    RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
        std::vector<DataType>({DT_FLOAT}),
        std::vector<TensorShape>({TensorShape()}), remote_node);
  }

  // Graph input: node A. Its shape is rank 0, so no dims are set.
  info.add_graph_input_node_name(NAME_A);
  info.add_default_graph_input_tensor_shape()->set_dtype(DT_FLOAT);

  // Graph output: node A + B. Also rank 0, so no dims are set.
  info.add_graph_output_node_name(NAME_A_PLUS_B);
  info.add_default_graph_output_tensor_shape()->set_dtype(DT_FLOAT);

  return info;
}

// 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph
class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
 public:
  int GetVersion() final { return 1; }
  bool Init(const RemoteFusedGraphExecuteInfo& info) final {
    info_ = &info;
    for (const NodeDef& node_def : info.remote_graph().node()) {
      node_def_map_.emplace(node_def.name(), &node_def);
    }
    return true;
  }
  bool Finalize() final { return true; }
  bool SetupGraph() final { return true; }
  bool ExecuteGraph() final {
    CHECK(info_ != nullptr);
    // TODO(satok): Add utilities to implement this function more easily.
    // CAVEAT: This test only handles add op. You can implement here as you
    // like.
    CHECK_EQ(1, info_->graph_input_node_name_size());
    const string& input_node_name = info_->graph_input_node_name(0);
    const Tensor& input_tensor = input_tensor_cache_[input_node_name];
    const float input_val = *input_tensor.scalar<float>().data();
    // TODO(satok): Read NAME_B from node_a_plus_b
    const NodeDef& node_b = *node_def_map_.at(NAME_B);
    const TensorProto* proto = nullptr;
    TF_CHECK_OK(GetNodeAttr(node_b, "value", &proto));
    Tensor const_tensor;
    TF_CHECK_OK(RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
        *proto, &const_tensor));
    const float b_val = *const_tensor.scalar<float>().data();
    Tensor output_a_plus_b(DT_FLOAT, {});
    output_a_plus_b.flat<float>().data()[0] = input_val + b_val;
    output_tensor_buf_.emplace(info_->graph_output_node_name(0),
                               output_a_plus_b);
    return true;
  }

  bool TeardownGraph() final { return true; }

  bool FillInputNode(const string& node_name, const Tensor& tensor) final {
    input_tensor_cache_[node_name] = tensor;
    return true;
  }

  bool ReadOutputNode(const string& node_name,
                      TensorAllocatorFunc tensor_allocator) final {
    // TODO(satok): Specify tensor shape by using default_graph_tensor_shape.
    const Tensor& buffered_output_tensor = output_tensor_buf_.at(node_name);
    const TensorShape& output_shape = buffered_output_tensor.shape();
    Tensor* output_tensor = tensor_allocator(output_shape);
    CHECK_EQ(buffered_output_tensor.dtype(), output_tensor->dtype());
    CHECK(output_tensor->CopyFrom(buffered_output_tensor, output_shape));
    return true;
  }

  Status FuseRemoteGraph(const GraphDef& original_graph_def,
                         const std::vector<string>& /*inputs*/,
                         const std::vector<string>& /*outputs*/,
                         GraphDef* fused_graph_def) final {
    *fused_graph_def = original_graph_def;
    return Status::OK();
  }

  bool IsEnabled() const final { return true; }

 private:
  const RemoteFusedGraphExecuteInfo* info_;
  std::unordered_map<string, Tensor> input_tensor_cache_;
  std::unordered_map<string, const NodeDef*> node_def_map_;
  std::unordered_map<string, Tensor> output_tensor_buf_;
};

// 2. Register a builder of your custom executor
namespace remote_fused_graph_execute_op {
// Factory passed to the registrar below. It stores a freshly constructed
// SampleRemoteFusedGraphExecutor in `*executor` and always returns OK.
Status BuildRemoteFusedGraphExecutor(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  std::unique_ptr<IRemoteFusedGraphExecutor> sample_executor(
      new SampleRemoteFusedGraphExecutor());
  executor->swap(sample_executor);
  return Status::OK();
}

// Constructing this static object registers BuildRemoteFusedGraphExecutor
// under REMOTE_FUSED_EXECUTOR_NAME for RemoteFusedGraphExecuteOp. Registering
// through a registrar keeps executors pluggable, so the op does not have to
// link libraries it does not need.
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
    k_test_remote_fused_graph_executor_build(REMOTE_FUSED_EXECUTOR_NAME,
                                             BuildRemoteFusedGraphExecutor);
}  // namespace remote_fused_graph_execute_op

// 3. Create Graph transform function to fuse your graph
// Builds the fused version of `original_graph`: a Placeholder for NAME_A that
// feeds a single RemoteFusedGraphExecute node (one output). That node carries
// the whole original graph inside its serialized execute info.
//
// Returns the error from Scope::ToGraphDef instead of crashing. On failure
// `*fused_graph` is left untouched. Node construction itself still CHECK-fails
// inside the helper builders.
static Status RewriteGraphToFusedGraph(const GraphDef& original_graph,
                                       GraphDef* fused_graph) {
  Scope root = Scope::NewRootScope();

  std::vector<Output> fused_op_inputs;
  fused_op_inputs.emplace_back(
      BuildPlaceHolderOp(NAME_A, DT_FLOAT, {}, &root));

  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfo(original_graph);
  BuildRemoteFusedGraphExecuteOp(REMOTE_FUSED_EXECUTE_OP_NODE_NAME,
                                 fused_op_inputs, /*output_node_count=*/1,
                                 execute_info, &root);

  // Serialize into a local first so a failure does not clobber *fused_graph.
  GraphDef fused_graph_def;
  const Status status = root.ToGraphDef(&fused_graph_def);
  if (!status.ok()) {
    return status;
  }
  fused_graph->Swap(&fused_graph_def);
  return Status::OK();
}

// 4. Register transform function
// A transform function can be registered with REGISTER_GRAPH_TRANSFORM.
// This test does not use the graph transform tool, to avoid linking against
// the graph transform library.
// To register the transform function, change the interface of
// RewriteGraphToFusedGraph to
// Status RewriteGraphToFusedGraph(
//     const GraphDef& original_graph, const TransformFuncContext& context,
//     GraphDef* output_graph_def);
// and then register it like:
// REGISTER_GRAPH_TRANSFORM("rewrite_graph", RewriteGraphToFusedGraph);

// 5. Fuse the original graph and run inference on the new fused graph
// Builds the add graph, fuses it into a single RemoteFusedGraphExecute node and
// checks that running the fused graph with NODE_A_VAL2 fed to node A yields
// NODE_A_VAL2 + NODE_B_VAL.
TEST(RemoteFusedExecuteGraphOp, EndToEndTest) {
  // 5.1 Load original graph
  GraphDef original_graph;
  TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
      NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &original_graph));

  // 5.2 Fuse graph
  GraphDef fused_graph;
  TF_ASSERT_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph));

  // 5.3 Setup session
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session(NewSession(session_options));
  ASSERT_TRUE(session != nullptr);
  // TF_ASSERT_OK prints the failing Status, unlike ASSERT_TRUE(status.ok()).
  TF_ASSERT_OK(session->Create(fused_graph));
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  // 5.4 Setup input
  Tensor input_a(DT_FLOAT, {});
  input_a.flat<float>().data()[0] = NODE_A_VAL2;
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back(NAME_A, input_a);

  // 5.5 Setup output: only the fused op's output is fetched.
  const std::vector<string> outputs{REMOTE_FUSED_EXECUTE_OP_NODE_NAME};

  // 5.6 Run inference on the fused graph
  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(session->Run(run_options, inputs, outputs, {}, &output_tensors,
                            &run_metadata));

  // 5.7 Check output tensor value
  // Unsigned literal avoids a signed/unsigned comparison with size().
  ASSERT_EQ(1u, output_tensors.size());
  EXPECT_NEAR(NODE_A_VAL2 + NODE_B_VAL,
              output_tensors.at(0).flat<float>().data()[0],
              FLOAT_VALUE_TOLERANCE);

  TF_ASSERT_OK(session->Close());
}

////////////////////////////
// End-to-end test: End   //
////////////////////////////

}  // namespace tensorflow