aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_op.cc
blob: 8a25f5832972e311a656c0c059198e7cd8eb4084 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/remote_fused_graph_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class RemoteFusedGraphExecuteOp : public OpKernel {
 public:
  explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
      : OpKernel(ctx), execute_info_() {
    string serialized_proto;
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr(RemoteFusedGraphExecuteUtils::
                              ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
                          &serialized_proto));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
    execute_info_.ParseFromString(serialized_proto);
    if (!execute_info_.executor_name().empty()) {
      const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func =
          RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
              execute_info_.executor_name());
      if (build_func != nullptr) {
        TF_CHECK_OK((*build_func)(&remote_fused_graph_executor_));
        CHECK(remote_fused_graph_executor_->IsEnabled());
      } else {
        LOG(ERROR) << "Executor not found for "
                   << execute_info_.executor_name();
      }
    }

    if (remote_fused_graph_executor_) {
      // 1. Initialize remote processor
      remote_fused_graph_executor_->Init(execute_info_);
      // Explicitly clear serialized executor parameter after initialization
      // to release unnecessary memory.
      execute_info_.clear_serialized_executor_parameters();

      // 2. Setup graph in remote processor
      remote_fused_graph_executor_->SetupGraph();
    }
  }

  ~RemoteFusedGraphExecuteOp() final {
    if (remote_fused_graph_executor_) {
      // 6. Teardown graph in remote processor
      remote_fused_graph_executor_->TeardownGraph();

      // 7. Finalize remote processor
      remote_fused_graph_executor_->Finalize();
    }
  }

  void Compute(OpKernelContext* const ctx) final {
    CHECK(ctx != nullptr);
    const int input_count = ctx->num_inputs();
    const int graph_input_count = execute_info_.graph_input_node_name_size();
    CHECK(input_count == graph_input_count &&
          input_count == input_types_.size())
        << "input_count = " << input_count
        << ", gt input count = " << execute_info_.graph_input_node_name_size()
        << ", type count = " << input_types_.size();

    // 3. Send first data type inputs into remote processor
    for (int i = 0; i < graph_input_count; ++i) {
      const Tensor& input_tensor = ctx->input(i);
      const string& input_node_name = execute_info_.graph_input_node_name(i);
      if (remote_fused_graph_executor_) {
        remote_fused_graph_executor_->FillInputNode(input_node_name,
                                                    input_tensor);
      }
    }

    // 4. Execute graph in remote processor
    if (remote_fused_graph_executor_) {
      remote_fused_graph_executor_->ExecuteGraph();
    }

    // 5. Load outputs from remote processor
    const int output_count = ctx->num_outputs();
    CHECK(output_count == execute_info_.graph_output_node_name_size() &&
          output_count == output_types_.size());
    for (int i = 0; i < output_count; ++i) {
      Tensor* output = nullptr;
      const string& output_node_name = execute_info_.graph_output_node_name(i);
      if (remote_fused_graph_executor_) {
        remote_fused_graph_executor_->ReadOutputNode(
            output_node_name,
            [i, &ctx, &output](const TensorShape& shape) -> Tensor* {
              TF_CHECK_OK(ctx->allocate_output(i, shape, &output));
              return output;
            });
      } else {
        // For compatibility purpose, returns an empty tensor with specified
        // data type as output if no executor is used.
        Tensor* output = nullptr;
        TensorShape ts({});
        TF_CHECK_OK(ctx->allocate_output(i, ts, &output));
      }
    }
  }

  bool IsExpensive() final { return true; }

 private:
  RemoteFusedGraphExecuteInfo execute_info_;
  std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_;
  DataTypeVector input_types_;
  DataTypeVector output_types_;

  TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);
};

REGISTER_KERNEL_BUILDER(Name("RemoteFusedGraphExecute").Device(DEVICE_CPU),
                        RemoteFusedGraphExecuteOp);

}  // namespace tensorflow