| author | Benoit Steiner <bsteiner@google.com> | 2018-02-21 21:05:42 -0800 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-21 21:09:41 -0800 |
| commit | b3df3aa4f5842fe3184088ef2fa0bb5d6edc21d5 | |
| tree | 9316cffac13f00535f2cf692e56cb0b0a44a6b2f /tensorflow/python/grappler | |
| parent | ddd66709a396644112e3dda165d53fdd485d7de3 | |
Started to open source the RL placer.
PiperOrigin-RevId: 186563773
Diffstat (limited to 'tensorflow/python/grappler')
| file | lines changed |
|---|---|
| tensorflow/python/grappler/cluster.i | 13 |
| tensorflow/python/grappler/cluster_test.py | 4 |
| tensorflow/python/grappler/controller.py | 142 |
| tensorflow/python/grappler/graph_placer.py | 110 |
| tensorflow/python/grappler/graph_placer_test.py | 140 |
| tensorflow/python/grappler/hierarchical_controller.py | 1098 |
| tensorflow/python/grappler/item.i | 16 |
| tensorflow/python/grappler/item_test.py | 2 |
8 files changed, 1514 insertions, 11 deletions
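
Editor's note: the patch wires an RL-based placer into Grappler's Python bindings. `controller.py` defines the generic policy interface, `hierarchical_controller.py` implements the grouper/placer model, and `graph_placer.py` exposes the `PlaceGraph` entry point that drives the search loop. As a reading aid, here is a minimal usage sketch modeled on `testBasic` from the new `graph_placer_test.py` below; the toy graph and the 60-second budget are illustrative, and a TF 1.x build containing this patch is assumed.

```python
# Minimal sketch of driving the new placer (assumes this patch is applied).
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import graph_placer
from tensorflow.python.ops import math_ops

with tf_ops.Graph().as_default() as g:
  a = constant_op.constant(10, name='a')
  b = constant_op.constant(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  # The placer only considers ops in the fanin of the train op collection.
  g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP).append(d)
  mg = meta_graph.create_meta_graph_def(graph=g)

# Profile the local machine's devices and search for up to 60 seconds;
# the best placement found is annotated back onto the metagraph.
placed_mg = graph_placer.PlaceGraph(
    mg, cluster=gcluster.Cluster(), allotted_time=60, verbose=True)
for node in placed_mg.graph_def.node:
  print(node.name, node.device)
```

Note that Grappler runs constant folding before the search, so constants may come back with an empty `device` field; the tests below only assert that any device actually assigned exists in the cluster.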
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 8079cb307b..067c8213d4 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -206,7 +206,7 @@ static PyObject* TF_ListDevices(GCluster cluster) { return result; } -static std::vector<string> TF_ListAvailableOps() { +static PyObject* TF_ListAvailableOps() { tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global(); std::vector<tensorflow::OpDef> ops; registry->GetRegisteredOps(&ops); @@ -215,7 +215,14 @@ static std::vector<string> TF_ListAvailableOps() { op_names.push_back(op.name()); } std::sort(op_names.begin(), op_names.end()); - return op_names; + + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyList_New(op_names.size()); + for (int i = 0; i < op_names.size(); ++i) { + PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str())); + } + PyGILState_Release(gstate); + return result; } static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) { @@ -432,7 +439,7 @@ static GCluster TF_NewVirtualCluster( TF_Status* out_status); static void TF_ShutdownCluster(GCluster cluster); static PyObject* TF_ListDevices(GCluster cluster); -static std::vector<string> TF_ListAvailableOps(); +static PyObject* TF_ListAvailableOps(); static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item); static float TF_EstimatePerformance(const tensorflow::NamedDevice& device); static PyObject* TF_MeasureCosts( diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index caae5b114e..a3c4c2bbeb 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -131,8 +131,8 @@ class ClusterTest(test.TestCase): def testAvailableOps(self): with cluster.Provision() as gcluster: op_names = gcluster.ListAvailableOps() - self.assertTrue(b'Add' in op_names) - self.assertTrue(b'MatMul' in op_names) + self.assertTrue('Add' in op_names) + self.assertTrue('MatMul' in op_names) self.assertEqual(op_names, sorted(op_names)) def testSupportDevices(self): diff --git a/tensorflow/python/grappler/controller.py b/tensorflow/python/grappler/controller.py new file mode 100644 index 0000000000..5677f4f523 --- /dev/null +++ b/tensorflow/python/grappler/controller.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Controller Class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + + +class Controller(object): + """Controller class.""" + + def __init__(self, item, cluster): + """Controller class initializer. + + Args: + item: The metagraph to place wrapped in a cluster. + cluster: A cluster of devices on which to place the item. 
+ """ + self.item = item + + self._node = {} + for node in item.metagraph.graph_def.node: + self._node[node.name] = node + + self._fanout = defaultdict(lambda: []) + for node in item.metagraph.graph_def.node: + for fanin in self._get_node_fanin(node): + self._fanout[fanin.name].append(node) + + important_op_names = item.IdentifyImportantOps(sort_topologically=True) + + # List of important ops (these are the ops to place) sorted in topological + # order. The order of this collection is deterministic. + self.important_ops = [] + for name in important_op_names: + self.important_ops.append(self._node[name]) + + self.node_properties = item.GetOpProperties() + + self.cluster = cluster + self.devices = cluster.ListDevices() + + self.colocation_constraints = item.GetColocationGroups() + + self.placement_constraints = cluster.GetSupportedDevices(item) + for node_name, dev in self.placement_constraints.items(): + if len(dev) == 1: + # Place the node on the supported device + node = self._node[node_name] + node.device = dev[0] + fanout = self.get_node_fanout(node) + # Update the fanout of the fanin to bypass the node + for fanin in self._get_node_fanin(node): + fanout_of_fanin = self.get_node_fanout(fanin) + fanout_of_fanin += fanout + fanout_of_fanin.remove(node) + # Remove node from the list of important ops since we don't need to + # place the node. + if node in self.important_ops: + self.important_ops.remove(node) + important_op_names.remove(node.name) + + # List of important op names, in non deterministic order. + self.important_op_names = frozenset(important_op_names) + + @property + def input_graph_def(self): + return self.item.metagraph.graph_def + + @property + def num_devices(self): + return len(self.devices) + + def get_node_by_name(self, node_name): + return self._node[node_name] + + def get_node_fanout(self, node): + return self._fanout[node.name] + + def get_placements(self, *args, **kwargs): + """Returns: Two TF ops. + + Args: + *args: "". + **kwargs: "". + + Returns: + y_preds: tensor of size [batch_size, num_ops] + log_probs: python dict of at least two fields: "sample", "target" each + containing a tensor of size [batch_size], corresponding to the log_probs. + """ + raise NotImplementedError + + def eval_placement(self, sess, *args, **kwargs): + """At this time, this method evaluates ONLY ONE placement. + + Args: + sess: a tf.Session() object used to retrieve cached assignment info. + *args: "". + **kwargs: "". + + Returns: + run_time: scalar + """ + raise NotImplementedError + + def export_placement(self, metagraph): + """Annotate the placement onto the specified metagraph. + + Args: + metagraph: the metagraph to annotate with the placement. + """ + for node in metagraph.graph_def.node: + if node.name in self.important_op_names: + node.device = self.get_node_by_name(node.name).device + + # Get the nodes in the immediate fanin of node. + # Beware: this doesn't take into account the nodes that may be skipped + # since placement constraints force their placement. + def _get_node_fanin(self, node): + input_ops = [] + for fanin_name in node.input: + if fanin_name[0] == "^": + fanin_name = fanin_name[1:] + fanin_name = fanin_name.split(":")[0] + input_ops.append(self.get_node_by_name(fanin_name)) + return input_ops diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py new file mode 100644 index 0000000000..2cc3536792 --- /dev/null +++ b/tensorflow/python/grappler/graph_placer.py @@ -0,0 +1,110 @@ +# Copyright 2018 The TensorFlow Authors. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Graph Placer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler import cluster as gcluster
+from tensorflow.python.grappler import hierarchical_controller
+from tensorflow.python.grappler import item as gitem
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.training import training
+
+
+def PlaceGraph(metagraph,
+               cluster=None,
+               allotted_time=3600,
+               hparams=None,
+               verbose=False):
+  """Place the provided metagraph.
+
+  Args:
+    metagraph: the metagraph to place.
+    cluster: an optional set of hardware resources to optimize the placement
+      for. If none is specified, we'll optimize the placement for the hardware
+      available on the local machine.
+    allotted_time: the maximum amount of time in seconds to spend optimizing
+      the placement.
+    hparams: hyperparameters used to fine tune the placer.
+    verbose: prints debug information if True.
+
+  Returns:
+    The placed metagraph.
+  """
+  if cluster is None:
+    cluster = gcluster.Cluster()
+
+  # Optimize the metagraph to speed up the placement.
+  rewriter_config = rewriter_config_pb2.RewriterConfig()
+  rewriter_config.optimizers.append("pruning")
+  rewriter_config.optimizers.append("constfold")
+  rewriter_config.optimizers.append("arithmetic")
+  rewriter_config.optimizers.append("dependency")
+  rewriter_config.optimizers.append("pruning")
+  optimized_graph = tf_optimizer.OptimizeGraph(
+      rewriter_config, metagraph, verbose=verbose, cluster=cluster)
+  optimized_metagraph = meta_graph_pb2.MetaGraphDef()
+  optimized_metagraph.CopyFrom(metagraph)
+  optimized_metagraph.graph_def.CopyFrom(optimized_graph)
+
+  item = gitem.Item(optimized_metagraph)
+
+  if hparams is None:
+    hparams = hierarchical_controller.hierarchical_controller_hparams()
+  # We run with a single child.
+  hparams.num_children = 1
+
+  with tf_ops.Graph().as_default():
+    # Place all the nodes of the controller on the CPU. We don't want them to
+    # fight for accelerator memory with the model to optimize.
+ with tf_ops.device("/device:CPU:0"): + model = hierarchical_controller.HierarchicalController( + hparams, item, cluster) + ops = model.build_controller() + session_creator = training.ChiefSessionCreator() + with training.MonitoredSession(session_creator=session_creator) as sess: + start_time = time.time() + current_time = start_time + while current_time - start_time < allotted_time: + grouping_actions = model.generate_grouping(sess) + input_to_seq2seq = model.create_group_embeddings( + grouping_actions, verbose=verbose) + model.generate_placement(input_to_seq2seq, sess) + try: + run_time = model.eval_placement( + sess, + verbose=verbose) + except errors.OpError as e: + if verbose: + print("Failed to run graph:" + str(e)) + run_time = hparams.failing_signal + updated = model.update_reward(sess, run_time, verbose=verbose) + if updated: + if verbose: + print("Found better placement, with runtime " + str(run_time)) + model.export_placement(metagraph) + + model.process_reward(sess) + + current_time = time.time() + + return metagraph diff --git a/tensorflow/python/grappler/graph_placer_test.py b/tensorflow/python/grappler/graph_placer_test.py new file mode 100644 index 0000000000..9eabe3cd54 --- /dev/null +++ b/tensorflow/python/grappler/graph_placer_test.py @@ -0,0 +1,140 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests the graph placer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.core.protobuf import device_properties_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.grappler import cluster +from tensorflow.python.grappler import graph_placer +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class GraphPlacerTest(test.TestCase): + + @staticmethod + def _buildMnist(batch_size=128, + input_size=256, + num_classes=1024, + num_layers=10, + hidden_size=256, + name='mnist'): + g = tf_ops.get_default_graph() + with g.as_default(): + ops = {} + x = random_ops.random_uniform( + [batch_size, input_size], -0.1, 0.1, dtype=dtypes.float32) + for layer_id in range(num_layers): + with variable_scope.variable_scope('layer_{}'.format(layer_id)): + a = input_size if layer_id == 0 else hidden_size + b = hidden_size if layer_id < num_layers - 1 else num_classes + w = variable_scope.get_variable('w', [a, b]) + x = math_ops.matmul(x, w) + x = nn_ops.relu(x) + ops['y_preds'] = math_ops.argmax(x, axis=1) + + train_op = g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) + train_op.append(ops['y_preds']) + return g + + @staticmethod + def _buildCluster(num_cpus=1, num_gpus=1): + devices = [] + if num_gpus > 0: + device_properties = device_properties_pb2.DeviceProperties( + type='GPU', + vendor='NVidia', + model='GeForce GTX TITAN X', + frequency=1076, + num_cores=24, + environment={'architecture': '5.2', + 'cuda': '8000', + 'cudnn': '6021'}, + num_registers=65536, + l1_cache_size=24576, + l2_cache_size=3145728, + shared_memory_size_per_multiprocessor=98304, + memory_size=12783648768, + bandwidth=336480000) + for i in range(num_gpus): + devices.append( + device_properties_pb2.NamedDevice( + properties=device_properties, name='/GPU:' + str(i))) + + assert num_cpus > 0 + device_properties = device_properties_pb2.DeviceProperties( + type='CPU', + frequency=2000, + num_cores=4, + l1_cache_size=32768, + l2_cache_size=262144, + l3_cache_size=12582912) + for i in range(num_cpus): + devices.append( + device_properties_pb2.NamedDevice( + properties=device_properties, name='/CPU:' + str(i))) + + return cluster.Cluster(devices=devices) + + def testBasic(self): + """Place a trivial graph.""" + a = constant_op.constant(10, name='a') + b = constant_op.constant(20, name='b') + c = math_ops.add_n([a, b], name='c') + d = math_ops.add_n([b, c], name='d') + train_op = tf_ops.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=tf_ops.get_default_graph()) + + gcluster = cluster.Cluster() + placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) + + self.assertEqual(4, len(placed_mg.graph_def.node)) + self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], + [node.name for node in mg.graph_def.node]) + + available_devices = [device.name for device in gcluster.ListDevices()] + for node in placed_mg.graph_def.node: + # The constant nodes are optimized away before the placer is run, and + # therefore won't be placed. 
+ self.assertTrue(not node.device or node.device in available_devices) + + def testMNIST(self): + graph = GraphPlacerTest._buildMnist() + mg = meta_graph.create_meta_graph_def(graph=graph) + gcluster = GraphPlacerTest._buildCluster(num_gpus=1) + # Spend 15 seconds trying to optimize the placement of the model. This + # should give us enough time to exercise the code, but not enough to find + # a good placement, so we'll just check for legality. + placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) + self.assertEqual(len(placed_mg.graph_def.node), len(mg.graph_def.node)) + self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], + [node.name for node in mg.graph_def.node]) + available_devices = [device.name for device in gcluster.ListDevices()] + for node in placed_mg.graph_def.node: + self.assertTrue(not node.device or node.device in available_devices) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py new file mode 100644 index 0000000000..655e43e78f --- /dev/null +++ b/tensorflow/python/grappler/hierarchical_controller.py @@ -0,0 +1,1098 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""HierarchicalController Class. + +The HierarchicalController encompasses the entire lifecycle of training the +device placement policy, including generating op embeddings, getting groups for +each op, placing those groups and running the predicted placements. + +Different assignment models can inherit from this class. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import six +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.grappler.controller import Controller +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training_util + + +class PlacerParams(object): + """Class to hold a set of placement parameters as name-value pairs. + + A typical usage is as follows: + + ```python + # Create a PlacerParams object specifying names and values of the model + # parameters: + params = PlacerParams(hidden_size=128, decay_steps=50) + + # The parameters are available as attributes of the PlacerParams object: + hparams.hidden_size ==> 128 + hparams.decay_steps ==> 50 + ``` + + """ + + def __init__(self, **kwargs): + """Create an instance of `PlacerParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the parameters. + The parameter types are inferred from the type of the values passed. + + The parameter names are added as attributes of `PlacerParams` object, + and they can be accessed directly with the dot notation `params._name_`. + + Example: + + ```python + # Define 1 parameter: 'hidden_size' + params = PlacerParams(hidden_size=128) + params.hidden_size ==> 128 + ``` + + Args: + **kwargs: Key-value pairs where the key is the parameter name and + the value is the value for the parameter. + """ + for name, value in six.iteritems(kwargs): + self.add_param(name, value) + + def add_param(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. + value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could be the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # parameter name. 
+ if getattr(self, name, None) is not None: + raise ValueError("Parameter name is reserved: %s" % name) + setattr(self, name, value) + + +def hierarchical_controller_hparams(): + """Hyperparameters for hierarchical planner.""" + return PlacerParams( + hidden_size=512, + forget_bias_init=1.0, + temperature=1.0, + logits_std_noise=0.5, + stop_noise_step=750, + decay_steps=50, + max_num_outputs=5, + max_output_size=5, + tanh_constant=1.0, + adj_embed_dim=20, + grouping_hidden_size=64, + num_groups=None, + bi_lstm=True, + failing_signal=100, + stop_sampling=500, + start_with_failing_signal=True, + always_update_baseline=False, + bl_dec=0.9, + grad_bound=1.0, + lr=0.1, + lr_dec=0.95, + start_decay_step=400, + optimizer_type="adam", + stop_updating_after_steps=1000, + name="hierarchical_controller", + keep_prob=1.0, + reward_function="sqrt", + seed=1234, + # distributed training params + num_children=1) + + +class HierarchicalController(Controller): + """HierarchicalController class.""" + + def __init__(self, hparams, item, cluster, controller_id=0): + """HierarchicalController class initializer. + + Args: + hparams: All hyper-parameters. + item: The metagraph to place. + cluster: The cluster of hardware devices to optimize for. + controller_id: the id of the controller in a multi-controller setup. + """ + super(HierarchicalController, self).__init__(item, cluster) + self.ctrl_id = controller_id + self.hparams = hparams + + if self.hparams.num_groups is None: + self.num_groups = min(256, 20 * self.num_devices) + else: + self.num_groups = self.hparams.num_groups + + # creates self.op_embeddings and self.type_dict + self.create_op_embeddings(verbose=False) + # TODO(azalia) clean up embedding/group_embedding_size names + self.group_emb_size = ( + 2 * self.num_groups + len(self.type_dict) + + self.hparams.max_num_outputs * self.hparams.max_output_size) + self.embedding_size = self.group_emb_size + self.initializer = init_ops.glorot_uniform_initializer( + seed=self.hparams.seed) + + with variable_scope.variable_scope( + self.hparams.name, + initializer=self.initializer, + reuse=variable_scope.AUTO_REUSE): + # define parameters of feedforward + variable_scope.get_variable("w_grouping_ff", [ + 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + + self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size + ]) + variable_scope.get_variable( + "w_grouping_softmax", + [self.hparams.grouping_hidden_size, self.num_groups]) + if self.hparams.bi_lstm: + variable_scope.get_variable("encoder_lstm_forward", [ + self.embedding_size + self.hparams.hidden_size / 2, + 2 * self.hparams.hidden_size + ]) + variable_scope.get_variable("encoder_lstm_backward", [ + self.embedding_size + self.hparams.hidden_size / 2, + 2 * self.hparams.hidden_size + ]) + variable_scope.get_variable( + "device_embeddings", [self.num_devices, self.hparams.hidden_size]) + variable_scope.get_variable( + "decoder_lstm", + [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) + variable_scope.get_variable( + "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) + variable_scope.get_variable("device_go_embedding", + [1, self.hparams.hidden_size]) + variable_scope.get_variable( + "encoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "decoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + 
variable_scope.get_variable( + "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable( + "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) + + else: + variable_scope.get_variable("encoder_lstm", [ + self.embedding_size + self.hparams.hidden_size, + 4 * self.hparams.hidden_size + ]) + variable_scope.get_variable( + "device_embeddings", [self.num_devices, self.hparams.hidden_size]) + variable_scope.get_variable( + "decoder_lstm", + [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) + variable_scope.get_variable( + "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) + variable_scope.get_variable("device_go_embedding", + [1, self.hparams.hidden_size]) + variable_scope.get_variable( + "encoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "decoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable( + "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) + seq2seq_input_layer = array_ops.placeholder_with_default( + array_ops.zeros([1, self.num_groups, self.group_emb_size], + dtypes.float32), + shape=(1, self.num_groups, self.group_emb_size)) + self.seq2seq_input_layer = seq2seq_input_layer + + def compute_reward(self, run_time): + if self.hparams.reward_function == "id": + reward = run_time + elif self.hparams.reward_function == "sqrt": + reward = math.sqrt(run_time) + elif self.hparams.reward_function == "log": + reward = math.log1p(run_time) + else: + raise NotImplementedError( + "Unrecognized reward function '%s', consider your " + "--reward_function flag value." % self.hparams.reward_function) + return reward + + def build_controller(self): + """RL optimization interface. + + Returns: + ops: A dictionary holding handles of the model used for training. 
+ """ + + self._global_step = training_util.get_or_create_global_step() + ops = {} + ops["loss"] = 0 + + failing_signal = self.compute_reward(self.hparams.failing_signal) + + ctr = {} + + with tf_ops.name_scope("controller_{}".format(self.ctrl_id)): + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + ctr["reward"] = {"value": [], "ph": [], "update": []} + ctr["ready"] = {"value": [], "ph": [], "update": []} + ctr["best_reward"] = {"value": [], "update": []} + for i in range(self.hparams.num_children): + reward_value = variable_scope.get_local_variable( + "reward_{}".format(i), + initializer=0.0, + dtype=dtypes.float32, + trainable=False) + reward_ph = array_ops.placeholder( + dtypes.float32, shape=(), name="reward_ph_{}".format(i)) + reward_update = state_ops.assign( + reward_value, reward_ph, use_locking=True) + ctr["reward"]["value"].append(reward_value) + ctr["reward"]["ph"].append(reward_ph) + ctr["reward"]["update"].append(reward_update) + best_reward = variable_scope.get_local_variable( + "best_reward_{}".format(i), + initializer=failing_signal, + dtype=dtypes.float32, + trainable=False) + ctr["best_reward"]["value"].append(best_reward) + ctr["best_reward"]["update"].append( + state_ops.assign(best_reward, + math_ops.minimum(best_reward, reward_update))) + + ready_value = variable_scope.get_local_variable( + "ready_{}".format(i), + initializer=True, + dtype=dtypes.bool, + trainable=False) + ready_ph = array_ops.placeholder( + dtypes.bool, shape=(), name="ready_ph_{}".format(i)) + ready_update = state_ops.assign( + ready_value, ready_ph, use_locking=True) + ctr["ready"]["value"].append(ready_value) + ctr["ready"]["ph"].append(ready_ph) + ctr["ready"]["update"].append(ready_update) + + ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings() + summary.histogram( + "grouping_actions", + array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0], + [1, array_ops.shape(self.op_embeddings)[0]])) + + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + ctr["baseline"] = variable_scope.get_local_variable( + "baseline", + initializer=failing_signal + if self.hparams.start_with_failing_signal else 0.0, + dtype=dtypes.float32, + trainable=False) + + new_baseline = self.hparams.bl_dec * ctr["baseline"] + ( + 1 - self.hparams.bl_dec) * math_ops.reduce_mean( + ctr["reward"]["value"]) + if not self.hparams.always_update_baseline: + baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal) + selected_reward = array_ops.boolean_mask(ctr["reward"]["value"], + baseline_mask) + selected_baseline = control_flow_ops.cond( + math_ops.reduce_any(baseline_mask), + lambda: math_ops.reduce_mean(selected_reward), + lambda: constant_op.constant(0, dtype=dtypes.float32)) + ctr["pos_reward"] = selected_baseline + pos_ = math_ops.less( + constant_op.constant(0, dtype=dtypes.float32), selected_baseline) + selected_baseline = self.hparams.bl_dec * ctr["baseline"] + ( + 1 - self.hparams.bl_dec) * selected_baseline + selected_baseline = control_flow_ops.cond( + pos_, lambda: selected_baseline, lambda: ctr["baseline"]) + new_baseline = control_flow_ops.cond( + math_ops.less(self.global_step, + self.hparams.stop_updating_after_steps), + lambda: new_baseline, lambda: selected_baseline) + ctr["baseline_update"] = state_ops.assign( + ctr["baseline"], new_baseline, use_locking=True) + + ctr["y_preds"], ctr["log_probs"] = self.get_placements() + summary.histogram("actions", ctr["y_preds"]["sample"]) + mask = math_ops.less(ctr["reward"]["value"], 
failing_signal) + ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"] + ctr["loss"] *= ( + ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"]) + + selected_loss = array_ops.boolean_mask(ctr["loss"], mask) + selected_loss = control_flow_ops.cond( + math_ops.reduce_any(mask), + lambda: math_ops.reduce_mean(-selected_loss), + lambda: constant_op.constant(0, dtype=dtypes.float32)) + + ctr["loss"] = control_flow_ops.cond( + math_ops.less(self.global_step, + self.hparams.stop_updating_after_steps), + lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss) + + ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"]) + summary.scalar("loss", ctr["loss"]) + summary.scalar("avg_reward", ctr["reward_s"]) + summary.scalar("best_reward_so_far", best_reward) + summary.scalar( + "advantage", + math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"])) + + with variable_scope.variable_scope( + "optimizer", reuse=variable_scope.AUTO_REUSE): + (ctr["train_op"], ctr["lr"], ctr["grad_norm"], + ctr["grad_norms"]) = self._get_train_ops( + ctr["loss"], + tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES), + self.global_step, + grad_bound=self.hparams.grad_bound, + lr_init=self.hparams.lr, + lr_dec=self.hparams.lr_dec, + start_decay_step=self.hparams.start_decay_step, + decay_steps=self.hparams.decay_steps, + optimizer_type=self.hparams.optimizer_type) + + summary.scalar("gradnorm", ctr["grad_norm"]) + summary.scalar("lr", ctr["lr"]) + ctr["summary"] = summary.merge_all() + ops["controller"] = ctr + + self.ops = ops + return ops + + @property + def global_step(self): + return self._global_step + + def create_op_embeddings(self, verbose=False): + if verbose: + print("process input graph for op embeddings") + self.num_ops = len(self.important_ops) + # topological sort of important nodes + topo_order = [op.name for op in self.important_ops] + + # create index to name for topologicaly sorted important nodes + name_to_topo_order_index = {} + for idx, x in enumerate(topo_order): + name_to_topo_order_index[x] = idx + self.name_to_topo_order_index = name_to_topo_order_index + + # create adj matrix + adj_dict = {} + for idx, op in enumerate(self.important_ops): + for output_op in self.get_node_fanout(op): + output_op_name = output_op.name + if output_op_name in self.important_op_names: + if name_to_topo_order_index[op.name] not in adj_dict: + adj_dict[name_to_topo_order_index[op.name]] = [] + adj_dict[name_to_topo_order_index[op.name]].extend( + [name_to_topo_order_index[output_op_name], 1]) + if output_op_name not in adj_dict: + adj_dict[name_to_topo_order_index[output_op_name]] = [] + adj_dict[name_to_topo_order_index[output_op_name]].extend( + [name_to_topo_order_index[op.name], -1]) + + # get op_type op_output_shape, and adj info + output_embed_dim = (self.hparams.max_num_outputs * + self.hparams.max_output_size) + + # TODO(bsteiner): don't filter based on used ops so that we can generalize + # to models that use other types of ops. 
+ used_ops = set() + for node in self.important_ops: + op_type = str(node.op) + used_ops.add(op_type) + + self.type_dict = {} + for op_type in self.cluster.ListAvailableOps(): + if op_type in used_ops: + self.type_dict[op_type] = len(self.type_dict) + + op_types = np.zeros([self.num_ops], dtype=np.int32) + op_output_shapes = np.full( + [self.num_ops, output_embed_dim], -1.0, dtype=np.float32) + for idx, node in enumerate(self.important_ops): + op_types[idx] = self.type_dict[node.op] + # output shape + op_name = node.name + for i, output_prop in enumerate(self.node_properties[op_name]): + if output_prop.shape.__str__() == "<unknown>": + continue + shape = output_prop.shape + for j, dim in enumerate(shape.dim): + if dim.size >= 0: + if i * self.hparams.max_output_size + j >= output_embed_dim: + break + op_output_shapes[idx, + i * self.hparams.max_output_size + j] = dim.size + # adj for padding + op_adj = np.full( + [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32) + for idx in adj_dict: + neighbors = adj_dict[int(idx)] + min_dim = min(self.hparams.adj_embed_dim, len(neighbors)) + padding_size = self.hparams.adj_embed_dim - min_dim + neighbors = neighbors[:min_dim] + [0] * padding_size + op_adj[int(idx)] = neighbors + + # op_embedding starts here + op_embeddings = np.zeros( + [ + self.num_ops, + 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + + self.hparams.adj_embed_dim + ], + dtype=np.float32) + for idx, op_name in enumerate(topo_order): + op_embeddings[idx] = np.concatenate( + (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)])) + self.op_embeddings = constant_op.constant( + op_embeddings, dtype=dtypes.float32) + if verbose: + print("num_ops = {}".format(self.num_ops)) + print("num_types = {}".format(len(self.type_dict))) + + def get_groupings(self, *args, **kwargs): + num_children = self.hparams.num_children + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + grouping_actions_cache = variable_scope.get_local_variable( + "grouping_actions_cache", + initializer=init_ops.zeros_initializer, + dtype=dtypes.int32, + shape=[num_children, self.num_ops], + trainable=False) + input_layer = self.op_embeddings + input_layer = array_ops.expand_dims(input_layer, 0) + feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1]) + grouping_actions, grouping_log_probs = {}, {} + grouping_actions["sample"], grouping_log_probs[ + "sample"] = self.make_grouping_predictions(feed_ff_input_layer) + + grouping_actions["sample"] = state_ops.assign(grouping_actions_cache, + grouping_actions["sample"]) + self.grouping_actions_cache = grouping_actions_cache + + return grouping_actions, grouping_log_probs + + def make_grouping_predictions(self, input_layer, reuse=None): + """model that predicts grouping (grouping_actions). 
+ + Args: + input_layer: group_input_layer + reuse: reuse + + Returns: + grouping_actions: actions + grouping_log_probs: log probabilities corresponding to actions + """ + with variable_scope.variable_scope(self.hparams.name, reuse=True): + # input_layer: tensor of size [1, num_ops, hidden_size] + w_grouping_ff = variable_scope.get_variable("w_grouping_ff") + w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax") + + batch_size = array_ops.shape(input_layer)[0] + embedding_dim = array_ops.shape(input_layer)[2] + + reshaped = array_ops.reshape(input_layer, + [batch_size * self.num_ops, embedding_dim]) + ff_output = math_ops.matmul(reshaped, w_grouping_ff) + logits = math_ops.matmul(ff_output, w_grouping_softmax) + if self.hparams.logits_std_noise > 0: + num_in_logits = math_ops.cast( + array_ops.size(logits), dtype=dtypes.float32) + avg_norm = math_ops.divide( + linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) + logits_noise = random_ops.random_normal( + array_ops.shape(logits), + stddev=self.hparams.logits_std_noise * avg_norm) + logits = control_flow_ops.cond( + self.global_step > self.hparams.stop_noise_step, lambda: logits, + lambda: logits + logits_noise) + logits = array_ops.reshape(logits, + [batch_size * self.num_ops, self.num_groups]) + actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed) + actions = math_ops.to_int32(actions) + actions = array_ops.reshape(actions, [batch_size, self.num_ops]) + action_label = array_ops.reshape(actions, [-1]) + log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=action_label) + log_probs = array_ops.reshape(log_probs, [batch_size, -1]) + log_probs = math_ops.reduce_sum(log_probs, 1) + grouping_actions = actions + grouping_log_probs = log_probs + return grouping_actions, grouping_log_probs + + def create_group_embeddings(self, grouping_actions, verbose=False): + """Approximating the blocks of a TF graph from a graph_def. + + Args: + grouping_actions: grouping predictions + verbose: print stuffs. + + Returns: + groups: list of groups. + """ + if verbose: + print("Processing input_graph") + + # TODO(azalia): Build inter-adjacencies dag matrix. 
+    # record dag_matrix
+    dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
+    for op in self.important_ops:
+      topo_op_index = self.name_to_topo_order_index[op.name]
+      # TODO(agoldie) child_id
+      group_index = grouping_actions[0][topo_op_index]
+      for output_op in self.get_node_fanout(op):
+        if output_op.name not in self.important_op_names:
+          continue
+        output_group_index = grouping_actions[0][self.name_to_topo_order_index[
+            output_op.name]]
+        dag_matrix[group_index, output_group_index] += 1.0
+    num_connections = np.sum(dag_matrix)
+    num_intra_group_connections = dag_matrix.trace()
+    num_inter_group_connections = num_connections - num_intra_group_connections
+    if verbose:
+      print("grouping evaluation metric")
+      print("num_connections={} num_intra_group_connections={} "
+            "num_inter_group_connections={}".format(
+                num_connections, num_intra_group_connections,
+                num_inter_group_connections))
+    self.dag_matrix = dag_matrix
+
+    # output_shape
+    op_output_shapes = np.zeros(
+        [
+            len(self.important_ops),
+            self.hparams.max_num_outputs * self.hparams.max_output_size
+        ],
+        dtype=np.float32)
+
+    for idx, op in enumerate(self.important_ops):
+      for i, output_properties in enumerate(self.node_properties[op.name]):
+        if output_properties.shape.__str__() == "<unknown>":
+          continue
+        if i > self.hparams.max_num_outputs:
+          break
+        shape = output_properties.shape
+        for j, dim in enumerate(shape.dim):
+          if dim.size > 0:
+            k = i * self.hparams.max_output_size + j
+            if k >= self.hparams.max_num_outputs * self.hparams.max_output_size:
+              break
+            op_output_shapes[idx, k] = dim.size
+
+    # group_embedding
+    group_embedding = np.zeros(
+        [
+            self.num_groups, len(self.type_dict) +
+            self.hparams.max_num_outputs * self.hparams.max_output_size
+        ],
+        dtype=np.float32)
+    for op_index, op in enumerate(self.important_ops):
+      group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]]
+      type_name = str(op.op)
+      type_index = self.type_dict[type_name]
+      group_embedding[group_index, type_index] += 1
+      group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams.
+ max_output_size] += ( + op_output_shapes[op_index]) + grouping_adjacencies = np.concatenate( + [dag_matrix, np.transpose(dag_matrix)], axis=1) + group_embedding = np.concatenate( + [grouping_adjacencies, group_embedding], axis=1) + group_normalizer = np.amax(group_embedding, axis=1, keepdims=True) + group_embedding /= (group_normalizer + 1.0) + if verbose: + print("Finished Processing Input Graph") + return group_embedding + + def get_placements(self, *args, **kwargs): + num_children = self.hparams.num_children + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + actions_cache = variable_scope.get_local_variable( + "actions_cache", + initializer=init_ops.zeros_initializer, + dtype=dtypes.int32, + shape=[num_children, self.num_groups], + trainable=False) + + x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1]) + last_c, last_h, attn_mem = self.encode(x) + actions, log_probs = {}, {} + actions["sample"], log_probs["sample"] = ( + self.decode( + x, last_c, last_h, attn_mem, mode="sample")) + actions["target"], log_probs["target"] = ( + self.decode( + x, + last_c, + last_h, + attn_mem, + mode="target", + y=actions_cache)) + actions["greedy"], log_probs["greedy"] = ( + self.decode( + x, last_c, last_h, attn_mem, mode="greedy")) + actions["sample"] = control_flow_ops.cond( + self.global_step < self.hparams.stop_sampling, + lambda: state_ops.assign(actions_cache, actions["sample"]), + lambda: state_ops.assign(actions_cache, actions["target"])) + self.actions_cache = actions_cache + + return actions, log_probs + + def encode(self, x): + """Encoder using LSTM. + + Args: + x: tensor of size [num_children, num_groups, embedding_size] + + Returns: + last_c, last_h: tensors of size [num_children, hidden_size], the final + LSTM states + attn_mem: tensor of size [num_children, num_groups, hidden_size], the + attention + memory, i.e. 
concatenation of all hidden states, linearly transformed by + an attention matrix attn_w_1 + """ + if self.hparams.bi_lstm: + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward") + w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward") + forget_bias = variable_scope.get_variable("encoder_forget_bias") + attn_w_1 = variable_scope.get_variable("attn_w_1") + else: + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm = variable_scope.get_variable("encoder_lstm") + forget_bias = variable_scope.get_variable("encoder_forget_bias") + attn_w_1 = variable_scope.get_variable("attn_w_1") + + embedding_size = array_ops.shape(x)[2] + + signals = array_ops.split(x, self.num_groups, axis=1) + for i in range(len(signals)): + signals[i] = array_ops.reshape( + signals[i], [self.hparams.num_children, embedding_size]) + + if self.hparams.bi_lstm: + + def body(i, prev_c_forward, prev_h_forward, prev_c_backward, + prev_h_backward): + """while loop for LSTM.""" + signal_forward = signals[i] + next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward, + prev_h_forward, w_lstm_forward, + forget_bias) + + signal_backward = signals[self.num_groups - 1 - i] + next_c_backward, next_h_backward = lstm( + signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward, + forget_bias) + + next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1) + all_h.append(next_h) + + return (next_c_forward, next_h_forward, next_c_backward, + next_h_backward) + + c_forward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + h_forward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + + c_backward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + h_backward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + all_h = [] + + for i in range(0, self.num_groups): + c_forward, h_forward, c_backward, h_backward = body( + i, c_forward, h_forward, c_backward, h_backward) + + last_c = array_ops.concat([c_forward, c_backward], axis=1) + last_h = array_ops.concat([h_forward, h_backward], axis=1) + attn_mem = array_ops.stack(all_h) + + else: + + def body(i, prev_c, prev_h): + signal = signals[i] + next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) + all_h.append(next_h) + return next_c, next_h + + c = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size], + dtype=dtypes.float32) + h = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size], + dtype=dtypes.float32) + all_h = [] + + for i in range(0, self.num_groups): + c, h = body(i, c, h) + + last_c = c + last_h = h + attn_mem = array_ops.stack(all_h) + + attn_mem = array_ops.transpose(attn_mem, [1, 0, 2]) + attn_mem = array_ops.reshape( + attn_mem, + [self.hparams.num_children * self.num_groups, self.hparams.hidden_size]) + attn_mem = math_ops.matmul(attn_mem, attn_w_1) + attn_mem = array_ops.reshape( + attn_mem, + [self.hparams.num_children, self.num_groups, self.hparams.hidden_size]) + + return last_c, last_h, attn_mem + + def decode(self, + x, + last_c, + last_h, + attn_mem, + mode="target", + y=None): + """Decoder using LSTM. + + Args: + x: tensor of size [num_children, num_groups, embedding_size]. 
+ last_c: tensor of size [num_children, hidden_size], the final LSTM states + computed by self.encoder. + last_h: same as last_c. + attn_mem: tensor of size [num_children, num_groups, hidden_size]. + mode: "target" or "sample". + y: tensor of size [num_children, num_groups], the device placements. + + Returns: + actions: tensor of size [num_children, num_groups], the placements of + devices + """ + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm = variable_scope.get_variable("decoder_lstm") + forget_bias = variable_scope.get_variable("decoder_forget_bias") + device_embeddings = variable_scope.get_variable("device_embeddings") + device_softmax = variable_scope.get_variable("device_softmax") + device_go_embedding = variable_scope.get_variable("device_go_embedding") + attn_w_2 = variable_scope.get_variable("attn_w_2") + attn_v = variable_scope.get_variable("attn_v") + + actions = tensor_array_ops.TensorArray( + dtypes.int32, + size=self.num_groups, + infer_shape=False, + clear_after_read=False) + + # pylint: disable=unused-argument + def condition(i, *args): + return math_ops.less(i, self.num_groups) + + # pylint: disable=missing-docstring + def body(i, prev_c, prev_h, actions, log_probs): + # pylint: disable=g-long-lambda + signal = control_flow_ops.cond( + math_ops.equal(i, 0), + lambda: array_ops.tile(device_go_embedding, + [self.hparams.num_children, 1]), + lambda: embedding_ops.embedding_lookup(device_embeddings, + actions.read(i - 1)) + ) + if self.hparams.keep_prob is not None: + signal = nn_ops.dropout(signal, self.hparams.keep_prob) + next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) + query = math_ops.matmul(next_h, attn_w_2) + query = array_ops.reshape( + query, [self.hparams.num_children, 1, self.hparams.hidden_size]) + query = math_ops.tanh(query + attn_mem) + query = array_ops.reshape(query, [ + self.hparams.num_children * self.num_groups, self.hparams.hidden_size + ]) + query = math_ops.matmul(query, attn_v) + query = array_ops.reshape(query, + [self.hparams.num_children, self.num_groups]) + query = nn_ops.softmax(query) + query = array_ops.reshape(query, + [self.hparams.num_children, self.num_groups, 1]) + query = math_ops.reduce_sum(attn_mem * query, axis=1) + query = array_ops.concat([next_h, query], axis=1) + logits = math_ops.matmul(query, device_softmax) + logits /= self.hparams.temperature + if self.hparams.tanh_constant > 0: + logits = math_ops.tanh(logits) * self.hparams.tanh_constant + if self.hparams.logits_std_noise > 0: + num_in_logits = math_ops.cast( + array_ops.size(logits), dtype=dtypes.float32) + avg_norm = math_ops.divide( + linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) + logits_noise = random_ops.random_normal( + array_ops.shape(logits), + stddev=self.hparams.logits_std_noise * avg_norm) + logits = control_flow_ops.cond( + self.global_step > self.hparams.stop_noise_step, lambda: logits, + lambda: logits + logits_noise) + + if mode == "sample": + next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed) + elif mode == "greedy": + next_y = math_ops.argmax(logits, 1) + elif mode == "target": + next_y = array_ops.slice(y, [0, i], [-1, 1]) + else: + raise NotImplementedError + next_y = math_ops.to_int32(next_y) + next_y = array_ops.reshape(next_y, [self.hparams.num_children]) + actions = actions.write(i, next_y) + log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=next_y) + return i + 1, next_c, next_h, actions, log_probs + + loop_vars = [ + constant_op.constant(0, 
dtype=dtypes.int32), last_c, last_h, actions,
+        array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
+    ]
+    loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)
+
+    last_c = loop_outputs[-4]
+    last_h = loop_outputs[-3]
+    actions = loop_outputs[-2].stack()
+    actions = array_ops.transpose(actions, [1, 0])
+    log_probs = loop_outputs[-1]
+    return actions, log_probs
+
+  def eval_placement(self,
+                     sess,
+                     child_id=0,
+                     verbose=False):
+    grouping_actions, actions = sess.run([
+        self.grouping_actions_cache,
+        self.actions_cache
+    ])
+    grouping_actions = grouping_actions[child_id]
+    actions = actions[child_id]
+    if verbose:
+      global_step = sess.run(self.global_step)
+      if global_step % 100 == 0:
+        log_string = "op group assignments: "
+        for a in grouping_actions:
+          log_string += "{} ".format(a)
+        print(log_string[:-1])
+        log_string = "group device assignments: "
+        for a in actions:
+          log_string += "{} ".format(a)
+        print(log_string[:-1])
+
+    for op in self.important_ops:
+      topo_order_index = self.name_to_topo_order_index[op.name]
+      group_index = grouping_actions[topo_order_index]
+      op.device = self.devices[actions[group_index]].name
+    try:
+      _, run_time, _ = self.cluster.MeasureCosts(self.item)
+    except errors.ResourceExhaustedError:
+      run_time = self.hparams.failing_signal
+    return run_time
+
+  def update_reward(self,
+                    sess,
+                    run_time,
+                    child_id=0,
+                    verbose=False):
+    reward = self.compute_reward(run_time)
+    controller_ops = self.ops["controller"]
+    _, best_reward = sess.run(
+        [
+            controller_ops["reward"]["update"][child_id],
+            controller_ops["best_reward"]["update"][child_id]
+        ],
+        feed_dict={
+            controller_ops["reward"]["ph"][child_id]: reward,
+        })
+    if verbose:
+      print("run_time={:<.5f} reward={:<.5f} "
+            "best_reward={:<.5f}".format(run_time, reward, best_reward))
+
+    # Reward is a double, best_reward a float: allow for some slack in the
+    # comparison.
+    updated = abs(best_reward - reward) < 1e-6
+    return updated
+
+  def generate_grouping(self, sess):
+    controller_ops = self.ops["controller"]
+    grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
+    return grouping_actions
+
+  def generate_placement(self, grouping, sess):
+    controller_ops = self.ops["controller"]
+    feed_seq2seq_input_dict = {}
+    feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims(
+        grouping, axis=0)
+    sess.run(
+        controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)
+
+  def process_reward(self, sess):
+    controller_ops = self.ops["controller"]
+    run_ops = [
+        controller_ops["loss"], controller_ops["lr"],
+        controller_ops["grad_norm"], controller_ops["grad_norms"],
+        controller_ops["train_op"]
+    ]
+    sess.run(run_ops)
+    sess.run(controller_ops["baseline_update"])
+
+  def _get_train_ops(self,
+                     loss,
+                     tf_variables,
+                     global_step,
+                     grad_bound=1.25,
+                     lr_init=1e-3,
+                     lr_dec=0.9,
+                     start_decay_step=10000,
+                     decay_steps=100,
+                     optimizer_type="adam"):
+    """Loss optimizer.
+
+    Args:
+      loss: scalar tf tensor
+      tf_variables: list of training variables, typically
+        tf.trainable_variables()
+      global_step: global_step
+      grad_bound: max gradient norm
+      lr_init: initial learning rate
+      lr_dec: learning rate decay coefficient
+      start_decay_step: start decaying learning rate after this many steps
+      decay_steps: apply decay rate factor at this step intervals
+      optimizer_type: optimizer type should be either adam or sgd
+
+    Returns:
+      train_op: training op
+      learning_rate: scalar learning rate tensor
+      grad_norm: l2 norm of the gradient vector
+      all_grad_norms: l2 norm of each component
+    """
+    lr_gstep = global_step - start_decay_step
+
+    def f1():
+      return constant_op.constant(lr_init)
+
+    def f2():
+      return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
+                                                   decay_steps, lr_dec, True)
+
+    learning_rate = control_flow_ops.cond(
+        math_ops.less(global_step, start_decay_step),
+        f1,
+        f2,
+        name="learning_rate")
+
+    if optimizer_type == "adam":
+      opt = adam.AdamOptimizer(learning_rate)
+    elif optimizer_type == "sgd":
+      opt = gradient_descent.GradientDescentOptimizer(learning_rate)
+    grads_and_vars = opt.compute_gradients(loss, tf_variables)
+    grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
+    all_grad_norms = {}
+    clipped_grads = []
+    clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
+    for g, v in grads_and_vars:
+      if g is not None:
+        if isinstance(g, tf_ops.IndexedSlices):
+          clipped = g.values / clipped_rate
+          norm_square = math_ops.reduce_sum(clipped * clipped)
+          clipped = tf_ops.IndexedSlices(clipped, g.indices)
+        else:
+          clipped = g / clipped_rate
+          norm_square = math_ops.reduce_sum(clipped * clipped)
+        all_grad_norms[v.name] = math_ops.sqrt(norm_square)
+        clipped_grads.append((clipped, v))
+
+    train_op = opt.apply_gradients(clipped_grads, global_step)
+    return train_op, learning_rate, grad_norm, all_grad_norms
+
+
+def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
+  """LSTM cell.
+
+  Args:
+    x: tensors of size [num_children, hidden_size].
+    prev_c: tensors of size [num_children, hidden_size].
+    prev_h: same as prev_c.
+    w_lstm: the cell's weight matrix, of size
+      [2 * hidden_size, 4 * hidden_size].
+    forget_bias: scalar bias added to the forget gate before the sigmoid.
+ + Returns: + next_c: + next_h: + """ + ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm) + i, f, o, g = array_ops.split(ifog, 4, axis=1) + i = math_ops.sigmoid(i) + f = math_ops.sigmoid(f + forget_bias) + o = math_ops.sigmoid(o) + g = math_ops.tanh(g) + next_c = i * g + f * prev_c + next_h = o * math_ops.tanh(next_c) + return next_c, next_h diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index d0fc1a04f2..9a84c60b04 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -96,10 +96,10 @@ static GItem TF_NewItem( return GItem(item.release()); } -static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically, +static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically, TF_Status* status) { if (item.is_none()) { - return {}; + Py_RETURN_NONE; } std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin(); @@ -132,7 +132,13 @@ static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topolog } } - return ops; + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyList_New(ops.size()); + for (int i = 0; i < ops.size(); ++i) { + PyList_SetItem(result, i, PyString_FromString(ops[i].c_str())); + } + PyGILState_Release(gstate); + return result; } static PyObject* TF_GetOpProperties(GItem item) { @@ -305,7 +311,7 @@ static PyObject* TF_GetColocationGroups(GItem item) { static GItem TF_NewItem( const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, bool ignore_user_placement, TF_Status* out_status); -static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically, - TF_Status* status); +static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically, + TF_Status* status); static PyObject* TF_GetOpProperties(GItem item); static PyObject* TF_GetColocationGroups(GItem item); diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py index cd70e2fdec..7c3efd6249 100644 --- a/tensorflow/python/grappler/item_test.py +++ b/tensorflow/python/grappler/item_test.py @@ -56,7 +56,7 @@ class ItemTest(test.TestCase): mg = meta_graph.create_meta_graph_def(graph=g) grappler_item = item.Item(mg) op_list = grappler_item.IdentifyImportantOps() - self.assertItemsEqual([b'Const', b'Const_1', b'add'], op_list) + self.assertItemsEqual(['Const', 'Const_1', 'add'], op_list) def testOpProperties(self): with ops.Graph().as_default() as g: |
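
Editor's note: for readers skimming `build_controller`, the training signal is a REINFORCE-style loss against an exponential-moving-average baseline. Stripped of the TF plumbing (the placeholder/assign machinery, the failing-signal masking, the logits noise), the arithmetic it encodes is roughly the following numpy sketch. Names mirror the hparams, and `neg_log_probs` stands in for `log_probs["sample"] + grouping_log_probs["sample"]`, which are cross-entropy values, i.e. `-log p(actions)`.

```python
# Hedged numpy sketch of the reward/baseline arithmetic in build_controller
# and compute_reward; not the TF implementation itself.
import math

import numpy as np

bl_dec = 0.9          # hparams.bl_dec: EMA decay of the reward baseline
failing_signal = 100  # hparams.failing_signal: run time charged to crashes


def compute_reward(run_time, reward_function="sqrt"):
  # Same mapping as HierarchicalController.compute_reward: compress long
  # run times so catastrophic placements don't dominate the gradient.
  if reward_function == "id":
    return run_time
  elif reward_function == "sqrt":
    return math.sqrt(run_time)
  elif reward_function == "log":
    return math.log1p(run_time)
  raise NotImplementedError(reward_function)


# start_with_failing_signal=True seeds the baseline pessimistically.
baseline = compute_reward(failing_signal)


def controller_loss(run_times, neg_log_probs):
  """REINFORCE loss for one batch of num_children sampled placements."""
  global baseline
  rewards = np.array([compute_reward(t) for t in run_times])
  # ctr["baseline_update"]: exponential moving average of the reward.
  baseline = bl_dec * baseline + (1 - bl_dec) * rewards.mean()
  # ctr["loss"] = mean(-(reward - baseline) * (-log p)): placements faster
  # than the baseline get their sampling probability pushed up.
  return np.mean(-(rewards - baseline) * neg_log_probs)
```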
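
A note on tuning: hyperparameters travel through the new `PlacerParams` container rather than `tf.contrib.training.HParams`, and they are overridden by plain attribute assignment, which is how `graph_placer.py` itself forces `num_children = 1`. A small sketch follows; the specific values are illustrative, not recommendations.

```python
# Overriding placer hyperparameters before handing them to PlaceGraph.
from tensorflow.python.grappler import hierarchical_controller

hparams = hierarchical_controller.hierarchical_controller_hparams()
hparams.reward_function = "log"  # compress long run times harder than "sqrt"
hparams.lr = 0.05                # smaller policy learning rate
hparams.stop_sampling = 1000     # keep sampling placements for more steps

# placed_mg = graph_placer.PlaceGraph(mg, hparams=hparams, allotted_time=600)
```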