aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-02-21 21:05:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 21:09:41 -0800
commitb3df3aa4f5842fe3184088ef2fa0bb5d6edc21d5 (patch)
tree9316cffac13f00535f2cf692e56cb0b0a44a6b2f /tensorflow/python/grappler
parentddd66709a396644112e3dda165d53fdd485d7de3 (diff)
Started to open source the RL placer.
PiperOrigin-RevId: 186563773
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/cluster.i13
-rw-r--r--tensorflow/python/grappler/cluster_test.py4
-rw-r--r--tensorflow/python/grappler/controller.py142
-rw-r--r--tensorflow/python/grappler/graph_placer.py110
-rw-r--r--tensorflow/python/grappler/graph_placer_test.py140
-rw-r--r--tensorflow/python/grappler/hierarchical_controller.py1098
-rw-r--r--tensorflow/python/grappler/item.i16
-rw-r--r--tensorflow/python/grappler/item_test.py2
8 files changed, 1514 insertions, 11 deletions
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i
index 8079cb307b..067c8213d4 100644
--- a/tensorflow/python/grappler/cluster.i
+++ b/tensorflow/python/grappler/cluster.i
@@ -206,7 +206,7 @@ static PyObject* TF_ListDevices(GCluster cluster) {
return result;
}
-static std::vector<string> TF_ListAvailableOps() {
+static PyObject* TF_ListAvailableOps() {
tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
std::vector<tensorflow::OpDef> ops;
registry->GetRegisteredOps(&ops);
@@ -215,7 +215,14 @@ static std::vector<string> TF_ListAvailableOps() {
op_names.push_back(op.name());
}
std::sort(op_names.begin(), op_names.end());
- return op_names;
+
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ PyObject* result = PyList_New(op_names.size());
+ for (int i = 0; i < op_names.size(); ++i) {
+ PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
+ }
+ PyGILState_Release(gstate);
+ return result;
}
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
@@ -432,7 +439,7 @@ static GCluster TF_NewVirtualCluster(
TF_Status* out_status);
static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(GCluster cluster);
-static std::vector<string> TF_ListAvailableOps();
+static PyObject* TF_ListAvailableOps();
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
static PyObject* TF_MeasureCosts(
diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py
index caae5b114e..a3c4c2bbeb 100644
--- a/tensorflow/python/grappler/cluster_test.py
+++ b/tensorflow/python/grappler/cluster_test.py
@@ -131,8 +131,8 @@ class ClusterTest(test.TestCase):
def testAvailableOps(self):
with cluster.Provision() as gcluster:
op_names = gcluster.ListAvailableOps()
- self.assertTrue(b'Add' in op_names)
- self.assertTrue(b'MatMul' in op_names)
+ self.assertTrue('Add' in op_names)
+ self.assertTrue('MatMul' in op_names)
self.assertEqual(op_names, sorted(op_names))
def testSupportDevices(self):
diff --git a/tensorflow/python/grappler/controller.py b/tensorflow/python/grappler/controller.py
new file mode 100644
index 0000000000..5677f4f523
--- /dev/null
+++ b/tensorflow/python/grappler/controller.py
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Controller Class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+
+class Controller(object):
+ """Controller class."""
+
+ def __init__(self, item, cluster):
+ """Controller class initializer.
+
+ Args:
+ item: The metagraph to place wrapped in a cluster.
+ cluster: A cluster of devices on which to place the item.
+ """
+ self.item = item
+
+ self._node = {}
+ for node in item.metagraph.graph_def.node:
+ self._node[node.name] = node
+
+ self._fanout = defaultdict(lambda: [])
+ for node in item.metagraph.graph_def.node:
+ for fanin in self._get_node_fanin(node):
+ self._fanout[fanin.name].append(node)
+
+ important_op_names = item.IdentifyImportantOps(sort_topologically=True)
+
+ # List of important ops (these are the ops to place) sorted in topological
+ # order. The order of this collection is deterministic.
+ self.important_ops = []
+ for name in important_op_names:
+ self.important_ops.append(self._node[name])
+
+ self.node_properties = item.GetOpProperties()
+
+ self.cluster = cluster
+ self.devices = cluster.ListDevices()
+
+ self.colocation_constraints = item.GetColocationGroups()
+
+ self.placement_constraints = cluster.GetSupportedDevices(item)
+ for node_name, dev in self.placement_constraints.items():
+ if len(dev) == 1:
+ # Place the node on the supported device
+ node = self._node[node_name]
+ node.device = dev[0]
+ fanout = self.get_node_fanout(node)
+ # Update the fanout of the fanin to bypass the node
+ for fanin in self._get_node_fanin(node):
+ fanout_of_fanin = self.get_node_fanout(fanin)
+ fanout_of_fanin += fanout
+ fanout_of_fanin.remove(node)
+ # Remove node from the list of important ops since we don't need to
+ # place the node.
+ if node in self.important_ops:
+ self.important_ops.remove(node)
+ important_op_names.remove(node.name)
+
+ # List of important op names, in non deterministic order.
+ self.important_op_names = frozenset(important_op_names)
+
+ @property
+ def input_graph_def(self):
+ return self.item.metagraph.graph_def
+
+ @property
+ def num_devices(self):
+ return len(self.devices)
+
+ def get_node_by_name(self, node_name):
+ return self._node[node_name]
+
+ def get_node_fanout(self, node):
+ return self._fanout[node.name]
+
+ def get_placements(self, *args, **kwargs):
+ """Returns: Two TF ops.
+
+ Args:
+ *args: "".
+ **kwargs: "".
+
+ Returns:
+ y_preds: tensor of size [batch_size, num_ops]
+ log_probs: python dict of at least two fields: "sample", "target" each
+ containing a tensor of size [batch_size], corresponding to the log_probs.
+ """
+ raise NotImplementedError
+
+ def eval_placement(self, sess, *args, **kwargs):
+ """At this time, this method evaluates ONLY ONE placement.
+
+ Args:
+ sess: a tf.Session() object used to retrieve cached assignment info.
+ *args: "".
+ **kwargs: "".
+
+ Returns:
+ run_time: scalar
+ """
+ raise NotImplementedError
+
+ def export_placement(self, metagraph):
+ """Annotate the placement onto the specified metagraph.
+
+ Args:
+ metagraph: the metagraph to annotate with the placement.
+ """
+ for node in metagraph.graph_def.node:
+ if node.name in self.important_op_names:
+ node.device = self.get_node_by_name(node.name).device
+
+ # Get the nodes in the immediate fanin of node.
+ # Beware: this doesn't take into account the nodes that may be skipped
+ # since placement constraints force their placement.
+ def _get_node_fanin(self, node):
+ input_ops = []
+ for fanin_name in node.input:
+ if fanin_name[0] == "^":
+ fanin_name = fanin_name[1:]
+ fanin_name = fanin_name.split(":")[0]
+ input_ops.append(self.get_node_by_name(fanin_name))
+ return input_ops
diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py
new file mode 100644
index 0000000000..2cc3536792
--- /dev/null
+++ b/tensorflow/python/grappler/graph_placer.py
@@ -0,0 +1,110 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Graph Placer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler import cluster as gcluster
+from tensorflow.python.grappler import hierarchical_controller
+from tensorflow.python.grappler import item as gitem
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.training import training
+
+
+def PlaceGraph(metagraph,
+ cluster=None,
+ allotted_time=3600,
+ hparams=None,
+ verbose=False):
+ """Place the provided metagraph.
+
+ Args:
+ metagraph: the metagraph to place.
+ cluster: an optional set of hardware resource to optimize the placement for.
+ If none is specified, we'll optimize the placement for the hardware
+ available on the local machine.
+ allotted_time: the maximum amount to time in seconds to spend optimizing
+ the placement.
+ hparams: hyperparameters used to fine tune the placer.
+ verbose: prints debug information if True.
+
+ Returns:
+ The placed metagraph.
+ """
+ if cluster is None:
+ cluster = gcluster.Cluster()
+
+ # Optimize the metagraph to speedup the placement
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.optimizers.append("pruning")
+ rewriter_config.optimizers.append("constfold")
+ rewriter_config.optimizers.append("arithmetic")
+ rewriter_config.optimizers.append("dependency")
+ rewriter_config.optimizers.append("pruning")
+ optimized_graph = tf_optimizer.OptimizeGraph(
+ rewriter_config, metagraph, verbose=verbose, cluster=cluster)
+ optimized_metagraph = meta_graph_pb2.MetaGraphDef()
+ optimized_metagraph.CopyFrom(metagraph)
+ optimized_metagraph.graph_def.CopyFrom(optimized_graph)
+
+ item = gitem.Item(optimized_metagraph)
+
+ if hparams is None:
+ hparams = hierarchical_controller.hierarchical_controller_hparams()
+ # We run with a single child
+ hparams.num_children = 1
+
+ with tf_ops.Graph().as_default():
+ # Place all the nodes of the controller on the CPU. We don't want them to
+ # fight for accelerator memory with the model to optimize.
+ with tf_ops.device("/device:CPU:0"):
+ model = hierarchical_controller.HierarchicalController(
+ hparams, item, cluster)
+ ops = model.build_controller()
+ session_creator = training.ChiefSessionCreator()
+ with training.MonitoredSession(session_creator=session_creator) as sess:
+ start_time = time.time()
+ current_time = start_time
+ while current_time - start_time < allotted_time:
+ grouping_actions = model.generate_grouping(sess)
+ input_to_seq2seq = model.create_group_embeddings(
+ grouping_actions, verbose=verbose)
+ model.generate_placement(input_to_seq2seq, sess)
+ try:
+ run_time = model.eval_placement(
+ sess,
+ verbose=verbose)
+ except errors.OpError as e:
+ if verbose:
+ print("Failed to run graph:" + str(e))
+ run_time = hparams.failing_signal
+ updated = model.update_reward(sess, run_time, verbose=verbose)
+ if updated:
+ if verbose:
+ print("Found better placement, with runtime " + str(run_time))
+ model.export_placement(metagraph)
+
+ model.process_reward(sess)
+
+ current_time = time.time()
+
+ return metagraph
diff --git a/tensorflow/python/grappler/graph_placer_test.py b/tensorflow/python/grappler/graph_placer_test.py
new file mode 100644
index 0000000000..9eabe3cd54
--- /dev/null
+++ b/tensorflow/python/grappler/graph_placer_test.py
@@ -0,0 +1,140 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests the graph placer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.core.protobuf import device_properties_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler import cluster
+from tensorflow.python.grappler import graph_placer
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class GraphPlacerTest(test.TestCase):
+
+ @staticmethod
+ def _buildMnist(batch_size=128,
+ input_size=256,
+ num_classes=1024,
+ num_layers=10,
+ hidden_size=256,
+ name='mnist'):
+ g = tf_ops.get_default_graph()
+ with g.as_default():
+ ops = {}
+ x = random_ops.random_uniform(
+ [batch_size, input_size], -0.1, 0.1, dtype=dtypes.float32)
+ for layer_id in range(num_layers):
+ with variable_scope.variable_scope('layer_{}'.format(layer_id)):
+ a = input_size if layer_id == 0 else hidden_size
+ b = hidden_size if layer_id < num_layers - 1 else num_classes
+ w = variable_scope.get_variable('w', [a, b])
+ x = math_ops.matmul(x, w)
+ x = nn_ops.relu(x)
+ ops['y_preds'] = math_ops.argmax(x, axis=1)
+
+ train_op = g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP)
+ train_op.append(ops['y_preds'])
+ return g
+
+ @staticmethod
+ def _buildCluster(num_cpus=1, num_gpus=1):
+ devices = []
+ if num_gpus > 0:
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='GPU',
+ vendor='NVidia',
+ model='GeForce GTX TITAN X',
+ frequency=1076,
+ num_cores=24,
+ environment={'architecture': '5.2',
+ 'cuda': '8000',
+ 'cudnn': '6021'},
+ num_registers=65536,
+ l1_cache_size=24576,
+ l2_cache_size=3145728,
+ shared_memory_size_per_multiprocessor=98304,
+ memory_size=12783648768,
+ bandwidth=336480000)
+ for i in range(num_gpus):
+ devices.append(
+ device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/GPU:' + str(i)))
+
+ assert num_cpus > 0
+ device_properties = device_properties_pb2.DeviceProperties(
+ type='CPU',
+ frequency=2000,
+ num_cores=4,
+ l1_cache_size=32768,
+ l2_cache_size=262144,
+ l3_cache_size=12582912)
+ for i in range(num_cpus):
+ devices.append(
+ device_properties_pb2.NamedDevice(
+ properties=device_properties, name='/CPU:' + str(i)))
+
+ return cluster.Cluster(devices=devices)
+
+ def testBasic(self):
+ """Place a trivial graph."""
+ a = constant_op.constant(10, name='a')
+ b = constant_op.constant(20, name='b')
+ c = math_ops.add_n([a, b], name='c')
+ d = math_ops.add_n([b, c], name='d')
+ train_op = tf_ops.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP)
+ train_op.append(d)
+ mg = meta_graph.create_meta_graph_def(graph=tf_ops.get_default_graph())
+
+ gcluster = cluster.Cluster()
+ placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster)
+
+ self.assertEqual(4, len(placed_mg.graph_def.node))
+ self.assertItemsEqual([node.name for node in placed_mg.graph_def.node],
+ [node.name for node in mg.graph_def.node])
+
+ available_devices = [device.name for device in gcluster.ListDevices()]
+ for node in placed_mg.graph_def.node:
+ # The constant nodes are optimized away before the placer is run, and
+ # therefore won't be placed.
+ self.assertTrue(not node.device or node.device in available_devices)
+
+ def testMNIST(self):
+ graph = GraphPlacerTest._buildMnist()
+ mg = meta_graph.create_meta_graph_def(graph=graph)
+ gcluster = GraphPlacerTest._buildCluster(num_gpus=1)
+ # Spend 15 seconds trying to optimize the placement of the model. This
+ # should give us enough time to exercise the code, but not enough to find
+ # a good placement, so we'll just check for legality.
+ placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster)
+ self.assertEqual(len(placed_mg.graph_def.node), len(mg.graph_def.node))
+ self.assertItemsEqual([node.name for node in placed_mg.graph_def.node],
+ [node.name for node in mg.graph_def.node])
+ available_devices = [device.name for device in gcluster.ListDevices()]
+ for node in placed_mg.graph_def.node:
+ self.assertTrue(not node.device or node.device in available_devices)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py
new file mode 100644
index 0000000000..655e43e78f
--- /dev/null
+++ b/tensorflow/python/grappler/hierarchical_controller.py
@@ -0,0 +1,1098 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""HierarchicalController Class.
+
+The HierarchicalController encompasses the entire lifecycle of training the
+device placement policy, including generating op embeddings, getting groups for
+each op, placing those groups and running the predicted placements.
+
+Different assignment models can inherit from this class.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import six
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler.controller import Controller
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+from tensorflow.python.training import adam
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training_util
+
+
+class PlacerParams(object):
+ """Class to hold a set of placement parameters as name-value pairs.
+
+ A typical usage is as follows:
+
+ ```python
+ # Create a PlacerParams object specifying names and values of the model
+ # parameters:
+ params = PlacerParams(hidden_size=128, decay_steps=50)
+
+ # The parameters are available as attributes of the PlacerParams object:
+ hparams.hidden_size ==> 128
+ hparams.decay_steps ==> 50
+ ```
+
+ """
+
+ def __init__(self, **kwargs):
+ """Create an instance of `PlacerParams` from keyword arguments.
+
+ The keyword arguments specify name-values pairs for the parameters.
+ The parameter types are inferred from the type of the values passed.
+
+ The parameter names are added as attributes of `PlacerParams` object,
+ and they can be accessed directly with the dot notation `params._name_`.
+
+ Example:
+
+ ```python
+ # Define 1 parameter: 'hidden_size'
+ params = PlacerParams(hidden_size=128)
+ params.hidden_size ==> 128
+ ```
+
+ Args:
+ **kwargs: Key-value pairs where the key is the parameter name and
+ the value is the value for the parameter.
+ """
+ for name, value in six.iteritems(kwargs):
+ self.add_param(name, value)
+
+ def add_param(self, name, value):
+ """Adds {name, value} pair to hyperparameters.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: Value of the hyperparameter. Can be one of the following types:
+ int, float, string, int list, float list, or string list.
+
+ Raises:
+ ValueError: if one of the arguments is invalid.
+ """
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+ # attribute of this object. In that case we refuse to use it as a
+ # parameter name.
+ if getattr(self, name, None) is not None:
+ raise ValueError("Parameter name is reserved: %s" % name)
+ setattr(self, name, value)
+
+
+def hierarchical_controller_hparams():
+ """Hyperparameters for hierarchical planner."""
+ return PlacerParams(
+ hidden_size=512,
+ forget_bias_init=1.0,
+ temperature=1.0,
+ logits_std_noise=0.5,
+ stop_noise_step=750,
+ decay_steps=50,
+ max_num_outputs=5,
+ max_output_size=5,
+ tanh_constant=1.0,
+ adj_embed_dim=20,
+ grouping_hidden_size=64,
+ num_groups=None,
+ bi_lstm=True,
+ failing_signal=100,
+ stop_sampling=500,
+ start_with_failing_signal=True,
+ always_update_baseline=False,
+ bl_dec=0.9,
+ grad_bound=1.0,
+ lr=0.1,
+ lr_dec=0.95,
+ start_decay_step=400,
+ optimizer_type="adam",
+ stop_updating_after_steps=1000,
+ name="hierarchical_controller",
+ keep_prob=1.0,
+ reward_function="sqrt",
+ seed=1234,
+ # distributed training params
+ num_children=1)
+
+
+class HierarchicalController(Controller):
+ """HierarchicalController class."""
+
+ def __init__(self, hparams, item, cluster, controller_id=0):
+ """HierarchicalController class initializer.
+
+ Args:
+ hparams: All hyper-parameters.
+ item: The metagraph to place.
+ cluster: The cluster of hardware devices to optimize for.
+ controller_id: the id of the controller in a multi-controller setup.
+ """
+ super(HierarchicalController, self).__init__(item, cluster)
+ self.ctrl_id = controller_id
+ self.hparams = hparams
+
+ if self.hparams.num_groups is None:
+ self.num_groups = min(256, 20 * self.num_devices)
+ else:
+ self.num_groups = self.hparams.num_groups
+
+ # creates self.op_embeddings and self.type_dict
+ self.create_op_embeddings(verbose=False)
+ # TODO(azalia) clean up embedding/group_embedding_size names
+ self.group_emb_size = (
+ 2 * self.num_groups + len(self.type_dict) +
+ self.hparams.max_num_outputs * self.hparams.max_output_size)
+ self.embedding_size = self.group_emb_size
+ self.initializer = init_ops.glorot_uniform_initializer(
+ seed=self.hparams.seed)
+
+ with variable_scope.variable_scope(
+ self.hparams.name,
+ initializer=self.initializer,
+ reuse=variable_scope.AUTO_REUSE):
+ # define parameters of feedforward
+ variable_scope.get_variable("w_grouping_ff", [
+ 1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+ self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size
+ ])
+ variable_scope.get_variable(
+ "w_grouping_softmax",
+ [self.hparams.grouping_hidden_size, self.num_groups])
+ if self.hparams.bi_lstm:
+ variable_scope.get_variable("encoder_lstm_forward", [
+ self.embedding_size + self.hparams.hidden_size / 2,
+ 2 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable("encoder_lstm_backward", [
+ self.embedding_size + self.hparams.hidden_size / 2,
+ 2 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable(
+ "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "decoder_lstm",
+ [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+ variable_scope.get_variable("device_go_embedding",
+ [1, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "encoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "decoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+
+ else:
+ variable_scope.get_variable("encoder_lstm", [
+ self.embedding_size + self.hparams.hidden_size,
+ 4 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable(
+ "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "decoder_lstm",
+ [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+ variable_scope.get_variable("device_go_embedding",
+ [1, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "encoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "decoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+ seq2seq_input_layer = array_ops.placeholder_with_default(
+ array_ops.zeros([1, self.num_groups, self.group_emb_size],
+ dtypes.float32),
+ shape=(1, self.num_groups, self.group_emb_size))
+ self.seq2seq_input_layer = seq2seq_input_layer
+
+ def compute_reward(self, run_time):
+ if self.hparams.reward_function == "id":
+ reward = run_time
+ elif self.hparams.reward_function == "sqrt":
+ reward = math.sqrt(run_time)
+ elif self.hparams.reward_function == "log":
+ reward = math.log1p(run_time)
+ else:
+ raise NotImplementedError(
+ "Unrecognized reward function '%s', consider your "
+ "--reward_function flag value." % self.hparams.reward_function)
+ return reward
+
+ def build_controller(self):
+ """RL optimization interface.
+
+ Returns:
+ ops: A dictionary holding handles of the model used for training.
+ """
+
+ self._global_step = training_util.get_or_create_global_step()
+ ops = {}
+ ops["loss"] = 0
+
+ failing_signal = self.compute_reward(self.hparams.failing_signal)
+
+ ctr = {}
+
+ with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ ctr["reward"] = {"value": [], "ph": [], "update": []}
+ ctr["ready"] = {"value": [], "ph": [], "update": []}
+ ctr["best_reward"] = {"value": [], "update": []}
+ for i in range(self.hparams.num_children):
+ reward_value = variable_scope.get_local_variable(
+ "reward_{}".format(i),
+ initializer=0.0,
+ dtype=dtypes.float32,
+ trainable=False)
+ reward_ph = array_ops.placeholder(
+ dtypes.float32, shape=(), name="reward_ph_{}".format(i))
+ reward_update = state_ops.assign(
+ reward_value, reward_ph, use_locking=True)
+ ctr["reward"]["value"].append(reward_value)
+ ctr["reward"]["ph"].append(reward_ph)
+ ctr["reward"]["update"].append(reward_update)
+ best_reward = variable_scope.get_local_variable(
+ "best_reward_{}".format(i),
+ initializer=failing_signal,
+ dtype=dtypes.float32,
+ trainable=False)
+ ctr["best_reward"]["value"].append(best_reward)
+ ctr["best_reward"]["update"].append(
+ state_ops.assign(best_reward,
+ math_ops.minimum(best_reward, reward_update)))
+
+ ready_value = variable_scope.get_local_variable(
+ "ready_{}".format(i),
+ initializer=True,
+ dtype=dtypes.bool,
+ trainable=False)
+ ready_ph = array_ops.placeholder(
+ dtypes.bool, shape=(), name="ready_ph_{}".format(i))
+ ready_update = state_ops.assign(
+ ready_value, ready_ph, use_locking=True)
+ ctr["ready"]["value"].append(ready_value)
+ ctr["ready"]["ph"].append(ready_ph)
+ ctr["ready"]["update"].append(ready_update)
+
+ ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings()
+ summary.histogram(
+ "grouping_actions",
+ array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
+ [1, array_ops.shape(self.op_embeddings)[0]]))
+
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ ctr["baseline"] = variable_scope.get_local_variable(
+ "baseline",
+ initializer=failing_signal
+ if self.hparams.start_with_failing_signal else 0.0,
+ dtype=dtypes.float32,
+ trainable=False)
+
+ new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+ 1 - self.hparams.bl_dec) * math_ops.reduce_mean(
+ ctr["reward"]["value"])
+ if not self.hparams.always_update_baseline:
+ baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+ selected_reward = array_ops.boolean_mask(ctr["reward"]["value"],
+ baseline_mask)
+ selected_baseline = control_flow_ops.cond(
+ math_ops.reduce_any(baseline_mask),
+ lambda: math_ops.reduce_mean(selected_reward),
+ lambda: constant_op.constant(0, dtype=dtypes.float32))
+ ctr["pos_reward"] = selected_baseline
+ pos_ = math_ops.less(
+ constant_op.constant(0, dtype=dtypes.float32), selected_baseline)
+ selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+ 1 - self.hparams.bl_dec) * selected_baseline
+ selected_baseline = control_flow_ops.cond(
+ pos_, lambda: selected_baseline, lambda: ctr["baseline"])
+ new_baseline = control_flow_ops.cond(
+ math_ops.less(self.global_step,
+ self.hparams.stop_updating_after_steps),
+ lambda: new_baseline, lambda: selected_baseline)
+ ctr["baseline_update"] = state_ops.assign(
+ ctr["baseline"], new_baseline, use_locking=True)
+
+ ctr["y_preds"], ctr["log_probs"] = self.get_placements()
+ summary.histogram("actions", ctr["y_preds"]["sample"])
+ mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+ ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
+ ctr["loss"] *= (
+ ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"])
+
+ selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
+ selected_loss = control_flow_ops.cond(
+ math_ops.reduce_any(mask),
+ lambda: math_ops.reduce_mean(-selected_loss),
+ lambda: constant_op.constant(0, dtype=dtypes.float32))
+
+ ctr["loss"] = control_flow_ops.cond(
+ math_ops.less(self.global_step,
+ self.hparams.stop_updating_after_steps),
+ lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss)
+
+ ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
+ summary.scalar("loss", ctr["loss"])
+ summary.scalar("avg_reward", ctr["reward_s"])
+ summary.scalar("best_reward_so_far", best_reward)
+ summary.scalar(
+ "advantage",
+ math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))
+
+ with variable_scope.variable_scope(
+ "optimizer", reuse=variable_scope.AUTO_REUSE):
+ (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
+ ctr["grad_norms"]) = self._get_train_ops(
+ ctr["loss"],
+ tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
+ self.global_step,
+ grad_bound=self.hparams.grad_bound,
+ lr_init=self.hparams.lr,
+ lr_dec=self.hparams.lr_dec,
+ start_decay_step=self.hparams.start_decay_step,
+ decay_steps=self.hparams.decay_steps,
+ optimizer_type=self.hparams.optimizer_type)
+
+ summary.scalar("gradnorm", ctr["grad_norm"])
+ summary.scalar("lr", ctr["lr"])
+ ctr["summary"] = summary.merge_all()
+ ops["controller"] = ctr
+
+ self.ops = ops
+ return ops
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ def create_op_embeddings(self, verbose=False):
+ if verbose:
+ print("process input graph for op embeddings")
+ self.num_ops = len(self.important_ops)
+ # topological sort of important nodes
+ topo_order = [op.name for op in self.important_ops]
+
+ # create index to name for topologicaly sorted important nodes
+ name_to_topo_order_index = {}
+ for idx, x in enumerate(topo_order):
+ name_to_topo_order_index[x] = idx
+ self.name_to_topo_order_index = name_to_topo_order_index
+
+ # create adj matrix
+ adj_dict = {}
+ for idx, op in enumerate(self.important_ops):
+ for output_op in self.get_node_fanout(op):
+ output_op_name = output_op.name
+ if output_op_name in self.important_op_names:
+ if name_to_topo_order_index[op.name] not in adj_dict:
+ adj_dict[name_to_topo_order_index[op.name]] = []
+ adj_dict[name_to_topo_order_index[op.name]].extend(
+ [name_to_topo_order_index[output_op_name], 1])
+ if output_op_name not in adj_dict:
+ adj_dict[name_to_topo_order_index[output_op_name]] = []
+ adj_dict[name_to_topo_order_index[output_op_name]].extend(
+ [name_to_topo_order_index[op.name], -1])
+
+ # get op_type op_output_shape, and adj info
+ output_embed_dim = (self.hparams.max_num_outputs *
+ self.hparams.max_output_size)
+
+ # TODO(bsteiner): don't filter based on used ops so that we can generalize
+ # to models that use other types of ops.
+ used_ops = set()
+ for node in self.important_ops:
+ op_type = str(node.op)
+ used_ops.add(op_type)
+
+ self.type_dict = {}
+ for op_type in self.cluster.ListAvailableOps():
+ if op_type in used_ops:
+ self.type_dict[op_type] = len(self.type_dict)
+
+ op_types = np.zeros([self.num_ops], dtype=np.int32)
+ op_output_shapes = np.full(
+ [self.num_ops, output_embed_dim], -1.0, dtype=np.float32)
+ for idx, node in enumerate(self.important_ops):
+ op_types[idx] = self.type_dict[node.op]
+ # output shape
+ op_name = node.name
+ for i, output_prop in enumerate(self.node_properties[op_name]):
+ if output_prop.shape.__str__() == "<unknown>":
+ continue
+ shape = output_prop.shape
+ for j, dim in enumerate(shape.dim):
+ if dim.size >= 0:
+ if i * self.hparams.max_output_size + j >= output_embed_dim:
+ break
+ op_output_shapes[idx,
+ i * self.hparams.max_output_size + j] = dim.size
+ # adj for padding
+ op_adj = np.full(
+ [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32)
+ for idx in adj_dict:
+ neighbors = adj_dict[int(idx)]
+ min_dim = min(self.hparams.adj_embed_dim, len(neighbors))
+ padding_size = self.hparams.adj_embed_dim - min_dim
+ neighbors = neighbors[:min_dim] + [0] * padding_size
+ op_adj[int(idx)] = neighbors
+
+ # op_embedding starts here
+ op_embeddings = np.zeros(
+ [
+ self.num_ops,
+ 1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+ self.hparams.adj_embed_dim
+ ],
+ dtype=np.float32)
+ for idx, op_name in enumerate(topo_order):
+ op_embeddings[idx] = np.concatenate(
+ (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)]))
+ self.op_embeddings = constant_op.constant(
+ op_embeddings, dtype=dtypes.float32)
+ if verbose:
+ print("num_ops = {}".format(self.num_ops))
+ print("num_types = {}".format(len(self.type_dict)))
+
+ def get_groupings(self, *args, **kwargs):
+ num_children = self.hparams.num_children
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ grouping_actions_cache = variable_scope.get_local_variable(
+ "grouping_actions_cache",
+ initializer=init_ops.zeros_initializer,
+ dtype=dtypes.int32,
+ shape=[num_children, self.num_ops],
+ trainable=False)
+ input_layer = self.op_embeddings
+ input_layer = array_ops.expand_dims(input_layer, 0)
+ feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1])
+ grouping_actions, grouping_log_probs = {}, {}
+ grouping_actions["sample"], grouping_log_probs[
+ "sample"] = self.make_grouping_predictions(feed_ff_input_layer)
+
+ grouping_actions["sample"] = state_ops.assign(grouping_actions_cache,
+ grouping_actions["sample"])
+ self.grouping_actions_cache = grouping_actions_cache
+
+ return grouping_actions, grouping_log_probs
+
+ def make_grouping_predictions(self, input_layer, reuse=None):
+ """model that predicts grouping (grouping_actions).
+
+ Args:
+ input_layer: group_input_layer
+ reuse: reuse
+
+ Returns:
+ grouping_actions: actions
+ grouping_log_probs: log probabilities corresponding to actions
+ """
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ # input_layer: tensor of size [1, num_ops, hidden_size]
+ w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
+ w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")
+
+ batch_size = array_ops.shape(input_layer)[0]
+ embedding_dim = array_ops.shape(input_layer)[2]
+
+ reshaped = array_ops.reshape(input_layer,
+ [batch_size * self.num_ops, embedding_dim])
+ ff_output = math_ops.matmul(reshaped, w_grouping_ff)
+ logits = math_ops.matmul(ff_output, w_grouping_softmax)
+ if self.hparams.logits_std_noise > 0:
+ num_in_logits = math_ops.cast(
+ array_ops.size(logits), dtype=dtypes.float32)
+ avg_norm = math_ops.divide(
+ linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+ logits_noise = random_ops.random_normal(
+ array_ops.shape(logits),
+ stddev=self.hparams.logits_std_noise * avg_norm)
+ logits = control_flow_ops.cond(
+ self.global_step > self.hparams.stop_noise_step, lambda: logits,
+ lambda: logits + logits_noise)
+ logits = array_ops.reshape(logits,
+ [batch_size * self.num_ops, self.num_groups])
+ actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+ actions = math_ops.to_int32(actions)
+ actions = array_ops.reshape(actions, [batch_size, self.num_ops])
+ action_label = array_ops.reshape(actions, [-1])
+ log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=action_label)
+ log_probs = array_ops.reshape(log_probs, [batch_size, -1])
+ log_probs = math_ops.reduce_sum(log_probs, 1)
+ grouping_actions = actions
+ grouping_log_probs = log_probs
+ return grouping_actions, grouping_log_probs
+
+ def create_group_embeddings(self, grouping_actions, verbose=False):
+ """Approximating the blocks of a TF graph from a graph_def.
+
+ Args:
+ grouping_actions: grouping predictions
+ verbose: print stuffs.
+
+ Returns:
+ groups: list of groups.
+ """
+ if verbose:
+ print("Processing input_graph")
+
+ # TODO(azalia): Build inter-adjacencies dag matrix.
+ # record dag_matrix
+ dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
+ for op in self.important_ops:
+ topo_op_index = self.name_to_topo_order_index[op.name]
+ # TODO(agoldie) child_id
+ group_index = grouping_actions[0][topo_op_index]
+ for output_op in self.get_node_fanout(op):
+ if output_op.name not in self.important_op_names:
+ continue
+ output_group_index = grouping_actions[0][self.name_to_topo_order_index[
+ output_op.name]]
+ dag_matrix[group_index, output_group_index] += 1.0
+ num_connections = np.sum(dag_matrix)
+ num_intra_group_connections = dag_matrix.trace()
+ num_inter_group_connections = num_connections - num_intra_group_connections
+ if verbose:
+ print("grouping evaluation metric")
+ print("num_connections={} num_intra_group_connections={} "
+ "num_inter_group_connections={}").format(
+ num_connections, num_intra_group_connections,
+ num_inter_group_connections)
+ self.dag_matrix = dag_matrix
+
+ # output_shape
+ op_output_shapes = np.zeros(
+ [
+ len(self.important_ops),
+ self.hparams.max_num_outputs * self.hparams.max_output_size
+ ],
+ dtype=np.float32)
+
+ for idx, op in enumerate(self.important_ops):
+ for i, output_properties in enumerate(self.node_properties[op.name]):
+ if output_properties.shape.__str__() == "<unknown>":
+ continue
+ if i > self.hparams.max_num_outputs:
+ break
+ shape = output_properties.shape
+ for j, dim in enumerate(shape.dim):
+ if dim.size > 0:
+ k = i * self.hparams.max_output_size + j
+ if k >= self.hparams.max_num_outputs * self.hparams.max_output_size:
+ break
+ op_output_shapes[idx, k] = dim.size
+
+ # group_embedding
+ group_embedding = np.zeros(
+ [
+ self.num_groups, len(self.type_dict) +
+ self.hparams.max_num_outputs * self.hparams.max_output_size
+ ],
+ dtype=np.float32)
+ for op_index, op in enumerate(self.important_ops):
+ group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]]
+ type_name = str(op.op)
+ type_index = self.type_dict[type_name]
+ group_embedding[group_index, type_index] += 1
+ group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams.
+ max_output_size] += (
+ op_output_shapes[op_index])
+ grouping_adjacencies = np.concatenate(
+ [dag_matrix, np.transpose(dag_matrix)], axis=1)
+ group_embedding = np.concatenate(
+ [grouping_adjacencies, group_embedding], axis=1)
+ group_normalizer = np.amax(group_embedding, axis=1, keepdims=True)
+ group_embedding /= (group_normalizer + 1.0)
+ if verbose:
+ print("Finished Processing Input Graph")
+ return group_embedding
+
+ def get_placements(self, *args, **kwargs):
+ num_children = self.hparams.num_children
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ actions_cache = variable_scope.get_local_variable(
+ "actions_cache",
+ initializer=init_ops.zeros_initializer,
+ dtype=dtypes.int32,
+ shape=[num_children, self.num_groups],
+ trainable=False)
+
+ x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1])
+ last_c, last_h, attn_mem = self.encode(x)
+ actions, log_probs = {}, {}
+ actions["sample"], log_probs["sample"] = (
+ self.decode(
+ x, last_c, last_h, attn_mem, mode="sample"))
+ actions["target"], log_probs["target"] = (
+ self.decode(
+ x,
+ last_c,
+ last_h,
+ attn_mem,
+ mode="target",
+ y=actions_cache))
+ actions["greedy"], log_probs["greedy"] = (
+ self.decode(
+ x, last_c, last_h, attn_mem, mode="greedy"))
+ actions["sample"] = control_flow_ops.cond(
+ self.global_step < self.hparams.stop_sampling,
+ lambda: state_ops.assign(actions_cache, actions["sample"]),
+ lambda: state_ops.assign(actions_cache, actions["target"]))
+ self.actions_cache = actions_cache
+
+ return actions, log_probs
+
+ def encode(self, x):
+ """Encoder using LSTM.
+
+ Args:
+ x: tensor of size [num_children, num_groups, embedding_size]
+
+ Returns:
+ last_c, last_h: tensors of size [num_children, hidden_size], the final
+ LSTM states
+ attn_mem: tensor of size [num_children, num_groups, hidden_size], the
+ attention
+ memory, i.e. concatenation of all hidden states, linearly transformed by
+ an attention matrix attn_w_1
+ """
+ if self.hparams.bi_lstm:
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward")
+ w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward")
+ forget_bias = variable_scope.get_variable("encoder_forget_bias")
+ attn_w_1 = variable_scope.get_variable("attn_w_1")
+ else:
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm = variable_scope.get_variable("encoder_lstm")
+ forget_bias = variable_scope.get_variable("encoder_forget_bias")
+ attn_w_1 = variable_scope.get_variable("attn_w_1")
+
+ embedding_size = array_ops.shape(x)[2]
+
+ signals = array_ops.split(x, self.num_groups, axis=1)
+ for i in range(len(signals)):
+ signals[i] = array_ops.reshape(
+ signals[i], [self.hparams.num_children, embedding_size])
+
+ if self.hparams.bi_lstm:
+
+ def body(i, prev_c_forward, prev_h_forward, prev_c_backward,
+ prev_h_backward):
+ """while loop for LSTM."""
+ signal_forward = signals[i]
+ next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward,
+ prev_h_forward, w_lstm_forward,
+ forget_bias)
+
+ signal_backward = signals[self.num_groups - 1 - i]
+ next_c_backward, next_h_backward = lstm(
+ signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward,
+ forget_bias)
+
+ next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1)
+ all_h.append(next_h)
+
+ return (next_c_forward, next_h_forward, next_c_backward,
+ next_h_backward)
+
+ c_forward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ h_forward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+
+ c_backward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ h_backward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ all_h = []
+
+ for i in range(0, self.num_groups):
+ c_forward, h_forward, c_backward, h_backward = body(
+ i, c_forward, h_forward, c_backward, h_backward)
+
+ last_c = array_ops.concat([c_forward, c_backward], axis=1)
+ last_h = array_ops.concat([h_forward, h_backward], axis=1)
+ attn_mem = array_ops.stack(all_h)
+
+ else:
+
+ def body(i, prev_c, prev_h):
+ signal = signals[i]
+ next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+ all_h.append(next_h)
+ return next_c, next_h
+
+ c = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size],
+ dtype=dtypes.float32)
+ h = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size],
+ dtype=dtypes.float32)
+ all_h = []
+
+ for i in range(0, self.num_groups):
+ c, h = body(i, c, h)
+
+ last_c = c
+ last_h = h
+ attn_mem = array_ops.stack(all_h)
+
+ attn_mem = array_ops.transpose(attn_mem, [1, 0, 2])
+ attn_mem = array_ops.reshape(
+ attn_mem,
+ [self.hparams.num_children * self.num_groups, self.hparams.hidden_size])
+ attn_mem = math_ops.matmul(attn_mem, attn_w_1)
+ attn_mem = array_ops.reshape(
+ attn_mem,
+ [self.hparams.num_children, self.num_groups, self.hparams.hidden_size])
+
+ return last_c, last_h, attn_mem
+
+ def decode(self,
+ x,
+ last_c,
+ last_h,
+ attn_mem,
+ mode="target",
+ y=None):
+ """Decoder using LSTM.
+
+ Args:
+ x: tensor of size [num_children, num_groups, embedding_size].
+ last_c: tensor of size [num_children, hidden_size], the final LSTM states
+ computed by self.encoder.
+ last_h: same as last_c.
+ attn_mem: tensor of size [num_children, num_groups, hidden_size].
+ mode: "target" or "sample".
+ y: tensor of size [num_children, num_groups], the device placements.
+
+ Returns:
+ actions: tensor of size [num_children, num_groups], the placements of
+ devices
+ """
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm = variable_scope.get_variable("decoder_lstm")
+ forget_bias = variable_scope.get_variable("decoder_forget_bias")
+ device_embeddings = variable_scope.get_variable("device_embeddings")
+ device_softmax = variable_scope.get_variable("device_softmax")
+ device_go_embedding = variable_scope.get_variable("device_go_embedding")
+ attn_w_2 = variable_scope.get_variable("attn_w_2")
+ attn_v = variable_scope.get_variable("attn_v")
+
+ actions = tensor_array_ops.TensorArray(
+ dtypes.int32,
+ size=self.num_groups,
+ infer_shape=False,
+ clear_after_read=False)
+
+ # pylint: disable=unused-argument
+ def condition(i, *args):
+ return math_ops.less(i, self.num_groups)
+
+ # pylint: disable=missing-docstring
+ def body(i, prev_c, prev_h, actions, log_probs):
+ # pylint: disable=g-long-lambda
+ signal = control_flow_ops.cond(
+ math_ops.equal(i, 0),
+ lambda: array_ops.tile(device_go_embedding,
+ [self.hparams.num_children, 1]),
+ lambda: embedding_ops.embedding_lookup(device_embeddings,
+ actions.read(i - 1))
+ )
+ if self.hparams.keep_prob is not None:
+ signal = nn_ops.dropout(signal, self.hparams.keep_prob)
+ next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+ query = math_ops.matmul(next_h, attn_w_2)
+ query = array_ops.reshape(
+ query, [self.hparams.num_children, 1, self.hparams.hidden_size])
+ query = math_ops.tanh(query + attn_mem)
+ query = array_ops.reshape(query, [
+ self.hparams.num_children * self.num_groups, self.hparams.hidden_size
+ ])
+ query = math_ops.matmul(query, attn_v)
+ query = array_ops.reshape(query,
+ [self.hparams.num_children, self.num_groups])
+ query = nn_ops.softmax(query)
+ query = array_ops.reshape(query,
+ [self.hparams.num_children, self.num_groups, 1])
+ query = math_ops.reduce_sum(attn_mem * query, axis=1)
+ query = array_ops.concat([next_h, query], axis=1)
+ logits = math_ops.matmul(query, device_softmax)
+ logits /= self.hparams.temperature
+ if self.hparams.tanh_constant > 0:
+ logits = math_ops.tanh(logits) * self.hparams.tanh_constant
+ if self.hparams.logits_std_noise > 0:
+ num_in_logits = math_ops.cast(
+ array_ops.size(logits), dtype=dtypes.float32)
+ avg_norm = math_ops.divide(
+ linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+ logits_noise = random_ops.random_normal(
+ array_ops.shape(logits),
+ stddev=self.hparams.logits_std_noise * avg_norm)
+ logits = control_flow_ops.cond(
+ self.global_step > self.hparams.stop_noise_step, lambda: logits,
+ lambda: logits + logits_noise)
+
+ if mode == "sample":
+ next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+ elif mode == "greedy":
+ next_y = math_ops.argmax(logits, 1)
+ elif mode == "target":
+ next_y = array_ops.slice(y, [0, i], [-1, 1])
+ else:
+ raise NotImplementedError
+ next_y = math_ops.to_int32(next_y)
+ next_y = array_ops.reshape(next_y, [self.hparams.num_children])
+ actions = actions.write(i, next_y)
+ log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=next_y)
+ return i + 1, next_c, next_h, actions, log_probs
+
+ loop_vars = [
+ constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions,
+ array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
+ ]
+ loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)
+
+ last_c = loop_outputs[-4]
+ last_h = loop_outputs[-3]
+ actions = loop_outputs[-2].stack()
+ actions = array_ops.transpose(actions, [1, 0])
+ log_probs = loop_outputs[-1]
+ return actions, log_probs
+
+ def eval_placement(self,
+ sess,
+ child_id=0,
+ verbose=False):
+ grouping_actions, actions = sess.run([
+ self.grouping_actions_cache,
+ self.actions_cache
+ ])
+ grouping_actions = grouping_actions[child_id]
+ actions = actions[child_id]
+ if verbose:
+ global_step = sess.run(self.global_step)
+ if global_step % 100 == 0:
+ log_string = "op group assignments: "
+ for a in grouping_actions:
+ log_string += "{} ".format(a)
+ print(log_string[:-1])
+ log_string = "group device assignments: "
+ for a in actions:
+ log_string += "{} ".format(a)
+ print(log_string[:-1])
+
+ for op in self.important_ops:
+ topo_order_index = self.name_to_topo_order_index[op.name]
+ group_index = grouping_actions[topo_order_index]
+ op.device = self.devices[actions[group_index]].name
+ try:
+ _, run_time, _ = self.cluster.MeasureCosts(self.item)
+ except errors.ResourceExhaustedError:
+ run_time = self.hparams.failing_signal
+ return run_time
+
+ def update_reward(self,
+ sess,
+ run_time,
+ child_id=0,
+ verbose=False):
+ reward = self.compute_reward(run_time)
+ controller_ops = self.ops["controller"]
+ _, best_reward = sess.run(
+ [
+ controller_ops["reward"]["update"][child_id],
+ controller_ops["best_reward"]["update"][child_id]
+ ],
+ feed_dict={
+ controller_ops["reward"]["ph"][child_id]: reward,
+ })
+ if verbose:
+ print("run_time={:<.5f} reward={:<.5f} "
+ "best_reward={:<.5f}").format(run_time, reward, best_reward)
+
+ # Reward is a double, best_reward a float: allow for some slack in the
+ # comparison.
+ updated = abs(best_reward - reward) < 1e-6
+ return updated
+
+ def generate_grouping(self, sess):
+ controller_ops = self.ops["controller"]
+ grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
+ return grouping_actions
+
+ def generate_placement(self, grouping, sess):
+ controller_ops = self.ops["controller"]
+ feed_seq2seq_input_dict = {}
+ feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims(
+ grouping, axis=0)
+ sess.run(
+ controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)
+
+ def process_reward(self, sess):
+ controller_ops = self.ops["controller"]
+ run_ops = [
+ controller_ops["loss"], controller_ops["lr"],
+ controller_ops["grad_norm"], controller_ops["grad_norms"],
+ controller_ops["train_op"]
+ ]
+ sess.run(run_ops)
+ sess.run(controller_ops["baseline_update"])
+
+ def _get_train_ops(self,
+ loss,
+ tf_variables,
+ global_step,
+ grad_bound=1.25,
+ lr_init=1e-3,
+ lr_dec=0.9,
+ start_decay_step=10000,
+ decay_steps=100,
+ optimizer_type="adam"):
+ """Loss optimizer.
+
+ Args:
+ loss: scalar tf tensor
+ tf_variables: list of training variables, typically
+ tf.trainable_variables()
+ global_step: global_step
+ grad_bound: max gradient norm
+ lr_init: initial learning rate
+ lr_dec: leaning rate decay coefficient
+ start_decay_step: start decaying learning rate after this many steps
+ decay_steps: apply decay rate factor at this step intervals
+ optimizer_type: optimizer type should be either adam or sgd
+
+ Returns:
+ train_op: training op
+ learning_rate: scalar learning rate tensor
+ grad_norm: l2 norm of the gradient vector
+ all_grad_norms: l2 norm of each component
+ """
+ lr_gstep = global_step - start_decay_step
+
+ def f1():
+ return constant_op.constant(lr_init)
+
+ def f2():
+ return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
+ decay_steps, lr_dec, True)
+
+ learning_rate = control_flow_ops.cond(
+ math_ops.less(global_step, start_decay_step),
+ f1,
+ f2,
+ name="learning_rate")
+
+ if optimizer_type == "adam":
+ opt = adam.AdamOptimizer(learning_rate)
+ elif optimizer_type == "sgd":
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate)
+ grads_and_vars = opt.compute_gradients(loss, tf_variables)
+ grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
+ all_grad_norms = {}
+ clipped_grads = []
+ clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
+ for g, v in grads_and_vars:
+ if g is not None:
+ if isinstance(g, tf_ops.IndexedSlices):
+ clipped = g.values / clipped_rate
+ norm_square = math_ops.reduce_sum(clipped * clipped)
+ clipped = tf_ops.IndexedSlices(clipped, g.indices)
+ else:
+ clipped = g / clipped_rate
+ norm_square = math_ops.reduce_sum(clipped * clipped)
+ all_grad_norms[v.name] = math_ops.sqrt(norm_square)
+ clipped_grads.append((clipped, v))
+
+ train_op = opt.apply_gradients(clipped_grads, global_step)
+ return train_op, learning_rate, grad_norm, all_grad_norms
+
+
+def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
+ """LSTM cell.
+
+ Args:
+ x: tensors of size [num_children, hidden_size].
+ prev_c: tensors of size [num_children, hidden_size].
+ prev_h: same as prev_c.
+ w_lstm: .
+ forget_bias: .
+
+ Returns:
+ next_c:
+ next_h:
+ """
+ ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm)
+ i, f, o, g = array_ops.split(ifog, 4, axis=1)
+ i = math_ops.sigmoid(i)
+ f = math_ops.sigmoid(f + forget_bias)
+ o = math_ops.sigmoid(o)
+ g = math_ops.tanh(g)
+ next_c = i * g + f * prev_c
+ next_h = o * math_ops.tanh(next_c)
+ return next_c, next_h
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i
index d0fc1a04f2..9a84c60b04 100644
--- a/tensorflow/python/grappler/item.i
+++ b/tensorflow/python/grappler/item.i
@@ -96,10 +96,10 @@ static GItem TF_NewItem(
return GItem(item.release());
}
-static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically,
+static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
TF_Status* status) {
if (item.is_none()) {
- return {};
+ Py_RETURN_NONE;
}
std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
@@ -132,7 +132,13 @@ static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topolog
}
}
- return ops;
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ PyObject* result = PyList_New(ops.size());
+ for (int i = 0; i < ops.size(); ++i) {
+ PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
+ }
+ PyGILState_Release(gstate);
+ return result;
}
static PyObject* TF_GetOpProperties(GItem item) {
@@ -305,7 +311,7 @@ static PyObject* TF_GetColocationGroups(GItem item) {
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* out_status);
-static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically,
- TF_Status* status);
+static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
+ TF_Status* status);
static PyObject* TF_GetOpProperties(GItem item);
static PyObject* TF_GetColocationGroups(GItem item);
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index cd70e2fdec..7c3efd6249 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -56,7 +56,7 @@ class ItemTest(test.TestCase):
mg = meta_graph.create_meta_graph_def(graph=g)
grappler_item = item.Item(mg)
op_list = grappler_item.IdentifyImportantOps()
- self.assertItemsEqual([b'Const', b'Const_1', b'add'], op_list)
+ self.assertItemsEqual(['Const', 'Const_1', 'add'], op_list)
def testOpProperties(self):
with ops.Graph().as_default() as g: