about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
author: Rui Zhao <rzhao@google.com>, 2018-03-01 14:15:23 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-03-01 14:23:03 -0800
commit: 3973e772ed84db08cb86b1086558223af29fd64a (patch)
tree: a9fffb193d1482b2d7adbcafd3ffe535fc90e345 /tensorflow/python/grappler
parent: f8f4a6e26cc1108495c0b9a55d9a7d6e7005c2b5 (diff)
Sampling group embeddings for each child.
PiperOrigin-RevId: 187532388
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- tensorflow/python/grappler/hierarchical_controller.py | 41
1 file changed, 30 insertions, 11 deletions
diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py
index b06fb3c6d0..c0866c1069 100644
--- a/tensorflow/python/grappler/hierarchical_controller.py
+++ b/tensorflow/python/grappler/hierarchical_controller.py
@@ -258,9 +258,11 @@ class HierarchicalController(Controller):
"attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
seq2seq_input_layer = array_ops.placeholder_with_default(
- array_ops.zeros([1, self.num_groups, self.group_emb_size],
+ array_ops.zeros([self.hparams.num_children,
+ self.num_groups,
+ self.group_emb_size],
dtypes.float32),
- shape=(1, self.num_groups, self.group_emb_size))
+ shape=(self.hparams.num_children, self.num_groups, self.group_emb_size))
self.seq2seq_input_layer = seq2seq_input_layer
def compute_reward(self, run_time):
@@ -585,12 +587,29 @@ class HierarchicalController(Controller):
"""Approximating the blocks of a TF graph from a graph_def.
Args:
- grouping_actions: grouping predictions
+ grouping_actions: grouping predictions.
verbose: print stuffs.
Returns:
groups: list of groups.
"""
+ groups = [
+ self._create_group_embeddings(grouping_actions, i, verbose) for
+ i in range(self.hparams.num_children)
+ ]
+ return np.stack(groups, axis=0)
+
+ def _create_group_embeddings(self, grouping_actions, child_id, verbose=False):
+ """Approximating the blocks of a TF graph from a graph_def for each child.
+
+ Args:
+ grouping_actions: grouping predictions.
+ child_id: child_id for the group.
+ verbose: print stuffs.
+
+ Returns:
+ groups: group embedding for the child_id.
+ """
if verbose:
print("Processing input_graph")
@@ -599,13 +618,13 @@ class HierarchicalController(Controller):
dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
for op in self.important_ops:
topo_op_index = self.name_to_topo_order_index[op.name]
- # TODO(agoldie) child_id
- group_index = grouping_actions[0][topo_op_index]
+ group_index = grouping_actions[child_id][topo_op_index]
for output_op in self.get_node_fanout(op):
if output_op.name not in self.important_op_names:
continue
- output_group_index = grouping_actions[0][self.name_to_topo_order_index[
- output_op.name]]
+ output_group_index = (
+ grouping_actions[child_id][self.name_to_topo_order_index[
+ output_op.name]])
dag_matrix[group_index, output_group_index] += 1.0
num_connections = np.sum(dag_matrix)
num_intra_group_connections = dag_matrix.trace()
@@ -648,7 +667,8 @@ class HierarchicalController(Controller):
],
dtype=np.float32)
for op_index, op in enumerate(self.important_ops):
- group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]]
+ group_index = grouping_actions[child_id][
+ self.name_to_topo_order_index[op.name]]
type_name = str(op.op)
type_index = self.type_dict[type_name]
group_embedding[group_index, type_index] += 1
@@ -675,7 +695,7 @@ class HierarchicalController(Controller):
shape=[num_children, self.num_groups],
trainable=False)
- x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1])
+ x = self.seq2seq_input_layer
last_c, last_h, attn_mem = self.encode(x)
actions, log_probs = {}, {}
actions["sample"], log_probs["sample"] = (
@@ -988,8 +1008,7 @@ class HierarchicalController(Controller):
def generate_placement(self, grouping, sess):
controller_ops = self.ops["controller"]
feed_seq2seq_input_dict = {}
- feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims(
- grouping, axis=0)
+ feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping
sess.run(
controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)