diff options
author | 2018-01-16 17:21:12 -0800 | |
---|---|---|
committer | 2018-01-16 17:24:57 -0800 | |
commit | aaac4ac3e9d1d8c48db9e4010459a417a07553d2 (patch) | |
tree | b4baf3185500bc188e08277f3b80c3a57b8dea89 /tensorflow/contrib/kfac | |
parent | cf327e8560fc044ab37e6a766c852e7b6546f228 (diff) |
K-FAC: Example using tf.estimator and K-FAC.
- Removes FisherEstimator.inv_updates_dict. Users should build their own dict directly from
FisherEstimator.inv_update_ops.
- Adds (cov|inv)_update_(thunks|ops) to KfacOptimizer.
PiperOrigin-RevId: 182135826
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/examples/convnet.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/mlp.py | 82 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/mlp_mnist_main.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/examples/tests/mlp_test.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/estimator.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/optimizer.py | 28 |
6 files changed, 121 insertions, 11 deletions
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 558bc294bc..39d80addaa 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -286,7 +286,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, damping=0.001, layer_collection=layer_collection, momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values()) + inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) sync_optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 4275ceadc2..0f0dbb53f4 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -239,3 +239,85 @@ def train_mnist_multitower(data_dir, }) return minimize( loss, accuracy, layer_collection, session_config=session_config) + + +def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST using tf.estimator. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + + # Load a dataset. + def input_fn(): + tf.logging.info("Loading MNIST into memory.") + return mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + def model_fn(features, labels, mode, params): + """Model function for MLP trained with K-FAC. + + Args: + features: Tensor of shape [batch_size, input_size]. Input features. + labels: Tensor of shape [batch_size]. Target labels for training. + mode: tf.estimator.ModeKey. Must be TRAIN. + params: ignored. + + Returns: + EstimatorSpec for training. 
+ + Raises: + ValueError: If 'mode' is anything other than TRAIN. + """ + del params + + if mode != tf.estimator.ModeKeys.TRAIN: + raise ValueError("Only training is supposed with this API.") + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + features, labels, num_labels=10, layer_collection=layer_collection) + + # Train with K-FAC. + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0001, + layer_collection=layer_collection, + momentum=0.99) + + # Run cov_update_op every step. Run 1 inv_update_ops per step. + cov_update_op = optimizer.cov_update_op + inv_update_op = tf.group( + tf.contrib.kfac.utils.batch_execute( + global_step, optimizer.inv_update_thunks, batch_size=1)) + with tf.control_dependencies([cov_update_op, inv_update_op]): + train_op = optimizer.minimize(loss, global_step=global_step) + + # Print metrics every 5 sec. + hooks = [ + tf.train.LoggingTensorHook( + { + "loss": loss, + "accuracy": accuracy + }, every_n_secs=5), + ] + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + + # Train until input_fn() is empty with Estimator. This is a prerequisite for + # TPU compatibility. 
+ estimator = tf.estimator.Estimator(model_fn=model_fn) + estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py index b318c71a56..9c34ade1d2 100644 --- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py @@ -33,7 +33,11 @@ FLAGS = None def main(argv): _ = argv - if FLAGS.num_towers > 1: + if FLAGS.use_estimator: + if FLAGS.num_towers != 1: + raise ValueError("Only 1 device supported in tf.estimator example.") + mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200) + elif FLAGS.num_towers > 1: mlp.train_mnist_multitower( FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) else: @@ -52,5 +56,9 @@ if __name__ == "__main__": type=int, default=1, help="Number of CPUs to split minibatch across.") + parser.add_argument( + "--use_estimator", + action="store_true", + help="Use tf.estimator API to train.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py index 34a942d27f..22da6c29f1 100644 --- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py +++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py @@ -53,6 +53,11 @@ class MlpTest(tf.test.TestCase): mlp.train_mnist_multitower( data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + def testTrainMnistEstimator(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. 
+ mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index d66395ded7..a7b1f9d35c 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -281,11 +281,6 @@ class FisherEstimator(object): return thunk - @property - def inv_updates_dict(self): - """Returns a dictionary mapping strings to inv_update_ops.""" - return {op.name: op for op in self.inv_update_ops} - def _get_grads_lists_gradients(self, tensors): grads_flat = gradients_impl.gradients( self._layers.total_sampled_loss(), diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 0c9444241f..1974b07acf 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -137,13 +137,33 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] self._losses = layer_collection.losses - self.cov_update_op = self._fisher_est.cov_update_op - self.inv_update_op = self._fisher_est.inv_update_op - self.inv_updates_dict = self._fisher_est.inv_updates_dict - super(KfacOptimizer, self).__init__(learning_rate, name=name) @property + def cov_update_thunks(self): + return self._fisher_est.cov_update_thunks + + @property + def cov_update_ops(self): + return self._fisher_est.cov_update_ops + + @property + def cov_update_op(self): + return self._fisher_est.cov_update_op + + @property + def inv_update_thunks(self): + return self._fisher_est.inv_update_thunks + + @property + def inv_update_ops(self): + return self._fisher_est.inv_update_ops + + @property + def inv_update_op(self): + return self._fisher_est.inv_update_op + + @property def variables(self): return self._fisher_est.variables |