aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler/controller.py
blob: 5677f4f52310dd68dc80c87275b50be95ba86b60 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Controller Class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict


class Controller(object):
  """Base class for op-placement controllers.

  Wraps a grappler item and a device cluster, precomputes the graph
  structure (per-node fanout), the topologically sorted list of
  "important" ops to place, and pins any node whose placement is forced
  by the cluster's supported-device constraints. Subclasses implement
  the actual placement policy via `get_placements` / `eval_placement`.
  """

  def __init__(self, item, cluster):
    """Controller class initializer.

    Args:
      item: The metagraph to place wrapped in a cluster.
      cluster: A cluster of devices on which to place the item.
    """
    self.item = item

    # Index every node of the graph by name for O(1) lookup.
    self._node = {node.name: node for node in item.metagraph.graph_def.node}

    # Map from node name to the list of nodes consuming one of its outputs
    # (i.e. the node's fanout).
    self._fanout = defaultdict(list)
    for node in item.metagraph.graph_def.node:
      for fanin in self._get_node_fanin(node):
        self._fanout[fanin.name].append(node)

    important_op_names = item.IdentifyImportantOps(sort_topologically=True)

    # List of important ops (these are the ops to place) sorted in topological
    # order. The order of this collection is deterministic.
    self.important_ops = [self._node[name] for name in important_op_names]

    self.node_properties = item.GetOpProperties()

    self.cluster = cluster
    self.devices = cluster.ListDevices()

    self.colocation_constraints = item.GetColocationGroups()

    self.placement_constraints = cluster.GetSupportedDevices(item)
    for node_name, dev in self.placement_constraints.items():
      if len(dev) == 1:
        # Only one supported device: the placement is forced, so pin the node
        # on that device and take it out of the placement problem.
        node = self._node[node_name]
        node.device = dev[0]
        fanout = self.get_node_fanout(node)
        # Update the fanout of the fanin to bypass the node.
        for fanin in self._get_node_fanin(node):
          fanout_of_fanin = self.get_node_fanout(fanin)
          fanout_of_fanin += fanout
          fanout_of_fanin.remove(node)
        # Remove node from the list of important ops since we don't need to
        # place the node.
        if node in self.important_ops:
          self.important_ops.remove(node)
          important_op_names.remove(node.name)

    # List of important op names, in non deterministic order.
    self.important_op_names = frozenset(important_op_names)

  @property
  def input_graph_def(self):
    """The GraphDef of the wrapped item."""
    return self.item.metagraph.graph_def

  @property
  def num_devices(self):
    """Number of devices available in the cluster."""
    return len(self.devices)

  def get_node_by_name(self, node_name):
    """Returns the NodeDef with the given name. Raises KeyError if absent."""
    return self._node[node_name]

  def get_node_fanout(self, node):
    """Returns the (mutable) list of nodes consuming an output of `node`."""
    return self._fanout[node.name]

  def get_placements(self, *args, **kwargs):
    """Builds the TF ops that sample and score placements.

    Args:
      *args: "".
      **kwargs: "".

    Returns:
      y_preds: tensor of size [batch_size, num_ops]
      log_probs: python dict of at least two fields: "sample", "target" each
      containing a tensor of size [batch_size], corresponding to the log_probs.
    """
    raise NotImplementedError

  def eval_placement(self, sess, *args, **kwargs):
    """At this time, this method evaluates ONLY ONE placement.

    Args:
      sess: a tf.Session() object used to retrieve cached assignment info.
      *args: "".
      **kwargs: "".

    Returns:
      run_time: scalar
    """
    raise NotImplementedError

  def export_placement(self, metagraph):
    """Annotate the placement onto the specified metagraph.

    Args:
      metagraph: the metagraph to annotate with the placement.
    """
    for node in metagraph.graph_def.node:
      if node.name in self.important_op_names:
        node.device = self.get_node_by_name(node.name).device

  def _get_node_fanin(self, node):
    """Returns the nodes in the immediate fanin of `node`.

    Beware: this doesn't take into account the nodes that may be skipped
    since placement constraints force their placement.

    Args:
      node: a NodeDef-like object with an `input` list of tensor names.

    Returns:
      The list of producer nodes, one entry per input edge.
    """
    input_ops = []
    for fanin_name in node.input:
      # Strip the control-dependency marker ("^name") and the output port
      # ("name:2"), keeping just the producer op name.
      if fanin_name.startswith("^"):
        fanin_name = fanin_name[1:]
      fanin_name = fanin_name.split(":")[0]
      input_ops.append(self.get_node_by_name(fanin_name))
    return input_ops