diff options
author | Rui Zhao <rzhao@google.com> | 2018-03-01 14:15:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-01 14:23:03 -0800 |
commit | 3973e772ed84db08cb86b1086558223af29fd64a (patch) | |
tree | a9fffb193d1482b2d7adbcafd3ffe535fc90e345 /tensorflow/python/grappler | |
parent | f8f4a6e26cc1108495c0b9a55d9a7d6e7005c2b5 (diff) |
Sampling group embeddings for each child.
PiperOrigin-RevId: 187532388
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/hierarchical_controller.py | 41 |
1 file changed, 30 insertions, 11 deletions
diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py index b06fb3c6d0..c0866c1069 100644 --- a/tensorflow/python/grappler/hierarchical_controller.py +++ b/tensorflow/python/grappler/hierarchical_controller.py @@ -258,9 +258,11 @@ class HierarchicalController(Controller): "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) seq2seq_input_layer = array_ops.placeholder_with_default( - array_ops.zeros([1, self.num_groups, self.group_emb_size], + array_ops.zeros([self.hparams.num_children, + self.num_groups, + self.group_emb_size], dtypes.float32), - shape=(1, self.num_groups, self.group_emb_size)) + shape=(self.hparams.num_children, self.num_groups, self.group_emb_size)) self.seq2seq_input_layer = seq2seq_input_layer def compute_reward(self, run_time): @@ -585,12 +587,29 @@ class HierarchicalController(Controller): """Approximating the blocks of a TF graph from a graph_def. Args: - grouping_actions: grouping predictions + grouping_actions: grouping predictions. verbose: print stuffs. Returns: groups: list of groups. """ + groups = [ + self._create_group_embeddings(grouping_actions, i, verbose) for + i in range(self.hparams.num_children) + ] + return np.stack(groups, axis=0) + + def _create_group_embeddings(self, grouping_actions, child_id, verbose=False): + """Approximating the blocks of a TF graph from a graph_def for each child. + + Args: + grouping_actions: grouping predictions. + child_id: child_id for the group. + verbose: print stuffs. + + Returns: + groups: group embedding for the child_id. 
+ """ if verbose: print("Processing input_graph") @@ -599,13 +618,13 @@ class HierarchicalController(Controller): dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32) for op in self.important_ops: topo_op_index = self.name_to_topo_order_index[op.name] - # TODO(agoldie) child_id - group_index = grouping_actions[0][topo_op_index] + group_index = grouping_actions[child_id][topo_op_index] for output_op in self.get_node_fanout(op): if output_op.name not in self.important_op_names: continue - output_group_index = grouping_actions[0][self.name_to_topo_order_index[ - output_op.name]] + output_group_index = ( + grouping_actions[child_id][self.name_to_topo_order_index[ + output_op.name]]) dag_matrix[group_index, output_group_index] += 1.0 num_connections = np.sum(dag_matrix) num_intra_group_connections = dag_matrix.trace() @@ -648,7 +667,8 @@ class HierarchicalController(Controller): ], dtype=np.float32) for op_index, op in enumerate(self.important_ops): - group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]] + group_index = grouping_actions[child_id][ + self.name_to_topo_order_index[op.name]] type_name = str(op.op) type_index = self.type_dict[type_name] group_embedding[group_index, type_index] += 1 @@ -675,7 +695,7 @@ class HierarchicalController(Controller): shape=[num_children, self.num_groups], trainable=False) - x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1]) + x = self.seq2seq_input_layer last_c, last_h, attn_mem = self.encode(x) actions, log_probs = {}, {} actions["sample"], log_probs["sample"] = ( @@ -988,8 +1008,7 @@ class HierarchicalController(Controller): def generate_placement(self, grouping, sess): controller_ops = self.ops["controller"] feed_seq2seq_input_dict = {} - feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims( - grouping, axis=0) + feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping sess.run( controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict) |