aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/node_builder.cc
blob: 8c34323dbe61e064ec486fbf027b0ad75d83a9fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "tensorflow/core/graph/node_builder.h"

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

NodeBuilder::NodeBuilder(const string& name, const string& op_name,
                         const OpRegistryInterface* op_registry)
    : def_builder_(name, op_name, op_registry) {}

NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def)
    : def_builder_(name, op_def) {}

// Adds output `src_index` of `src_node` as the next data input.
// The (node, index) pair is always remembered for edge creation in
// Finalize(). The NodeDef input is only added when GetOutputType() succeeds.
// On failure GetOutputType() records an error, which makes Finalize() fail.
NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
  inputs_.emplace_back(src_node, src_index);
  DataType output_type;
  const bool resolved = GetOutputType(src_node, src_index, &output_type);
  if (resolved) def_builder_.Input(src_node->name(), src_index, output_type);
  return *this;
}

// Adds a pre-resolved NodeOut as the next data input. A NodeOut flagged with
// `error` is not added; instead an error is recorded for Finalize() to report.
NodeBuilder& NodeBuilder::Input(NodeOut src) {
  if (src.error) {
    AddIndexError(src.node, src.index);
    return *this;
  }
  inputs_.emplace_back(src.node, src.index);
  def_builder_.Input(src.name, src.index, src.dt);
  return *this;
}

// Adds a list of NodeOuts as a single (list-typed) input of the NodeDef.
// Entries flagged with `error` are skipped and reported through
// AddIndexError(); the remaining entries are forwarded in their original order.
NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
  std::vector<NodeDefBuilder::NodeOut> accepted;
  accepted.reserve(src_list.size());
  for (size_t k = 0; k < src_list.size(); ++k) {
    const NodeOut& entry = src_list[k];
    if (entry.error) {
      AddIndexError(entry.node, entry.index);
      continue;
    }
    accepted.emplace_back(entry.name, entry.index, entry.dt);
    inputs_.emplace_back(entry.node, entry.index);
  }
  def_builder_.Input(accepted);
  return *this;
}

// Adds `src_node` as a control dependency. The Node* is kept for the
// control edge created in Finalize(), and its name goes into the NodeDef.
NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
  control_inputs_.push_back(src_node);
  const string& dependency_name = src_node->name();
  def_builder_.ControlInput(dependency_name);
  return *this;
}

// Adds every node in `src_nodes`, in order, as a control dependency. Each
// Node* is appended to control_inputs_ and its name is added to the NodeDef.
NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
  for (Node* dependency : src_nodes) {
    control_inputs_.push_back(dependency);
    def_builder_.ControlInput(dependency->name());
  }
  return *this;
}

// Forwards the requested device specification to the NodeDef under
// construction and returns this builder for chaining.
NodeBuilder& NodeBuilder::Device(const string& device_spec) {
  NodeBuilder& self = *this;
  def_builder_.Device(device_spec);
  return self;
}

// Builds and validates the NodeDef, adds the resulting node to `graph`, and
// wires up its data and control edges.
// On success, *created_node (if non-null) receives the new node.
// On any failure, *created_node (if non-null) is left as nullptr and the
// error Status is returned. Errors recorded earlier by Input() and its
// helpers are joined with newlines into one InvalidArgument status.
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
  // Clear the out-param first so every early return leaves it null.
  if (created_node != nullptr) {
    *created_node = nullptr;
  }
  if (!errors_.empty()) {
    return errors::InvalidArgument(str_util::Join(errors_, "\n"));
  }

  NodeDef node_def;
  TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
  TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));

  Status add_status;
  Node* const new_node = graph->AddNode(node_def, &add_status);
  if (!add_status.ok()) return add_status;

  // Input slot `slot` of new_node is fed by inputs_[slot]. Entries with a null
  // source node are back edges and are not wired here.
  for (size_t slot = 0; slot < inputs_.size(); ++slot) {
    const auto& source = inputs_[slot];
    if (source.node == nullptr) continue;
    graph->AddEdge(source.node, source.index, new_node, slot);
  }
  for (Node* const dependency : control_inputs_) {
    graph->AddControlEdge(dependency, new_node);
  }

  if (created_node != nullptr) {
    *created_node = new_node;
  }
  return Status::OK();
}

// Records an error for an input that could not be added to the node being
// built. A null `node` gets a dedicated message. Otherwise the message reports
// that output `i` of `node` is not in range [0, node->num_outputs()).
// Both messages name the op type being built. Finalize() returns all recorded
// errors as a single InvalidArgument status.
void NodeBuilder::AddIndexError(Node* node, int i) {
  if (node == nullptr) {
    // The trailing space keeps the op name from being glued onto "type",
    // matching the spacing of the out-of-range message below.
    errors_.emplace_back(
        strings::StrCat("Attempt to add nullptr Node to node with type ",
                        def_builder_.op_def().name()));
  } else {
    errors_.emplace_back(
        strings::StrCat("Attempt to add output ", i, " of ", node->name(),
                        " not in range [0, ", node->num_outputs(),
                        ") to node with type ", def_builder_.op_def().name()));
  }
}

// Looks up the type of output `i` of `node` via SafeGetOutput() and stores it
// in *dt. If SafeGetOutput() reports a failure, an error is recorded via
// AddIndexError() and false is returned. Otherwise returns true.
bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) {
  bool failed = false;
  *dt = SafeGetOutput(node, i, &failed);
  if (!failed) return true;
  AddIndexError(node, i);
  return false;
}

}  // namespace tensorflow