author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-16 17:21:12 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-01-16 17:24:57 -0800
commit    aaac4ac3e9d1d8c48db9e4010459a417a07553d2 (patch)
tree      b4baf3185500bc188e08277f3b80c3a57b8dea89 /tensorflow/contrib/kfac
parent    cf327e8560fc044ab37e6a766c852e7b6546f228 (diff)
K-FAC: Example using tf.estimator and K-FAC.
- Removes FisherEstimator.inv_updates_dict. Users should construct update ops
  directly from FisherEstimator.inv_update_ops.
- Adds (cov|inv)_update_(thunks|ops) properties to KfacOptimizer.

PiperOrigin-RevId: 182135826
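For users of the removed property, a minimal migration sketch (assuming `estimator` is an already-constructed FisherEstimator; graph setup is not shown):

    # Before this commit, callers pulled update ops out of a name-keyed dict:
    #   inv_ops = list(estimator.inv_updates_dict.values())
    # After this commit, the list of ops is available directly:
    inv_ops = estimator.inv_update_ops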
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--  tensorflow/contrib/kfac/examples/convnet.py         |  2
-rw-r--r--  tensorflow/contrib/kfac/examples/mlp.py             | 82
-rw-r--r--  tensorflow/contrib/kfac/examples/mlp_mnist_main.py  | 10
-rw-r--r--  tensorflow/contrib/kfac/examples/tests/mlp_test.py  |  5
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py     |  5
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py     | 28
6 files changed, 121 insertions, 11 deletions
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
index 558bc294bc..39d80addaa 100644
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ b/tensorflow/contrib/kfac/examples/convnet.py
@@ -286,7 +286,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master,
damping=0.001,
layer_collection=layer_collection,
momentum=0.9)
- inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values())
+ inv_update_queue = oq.OpQueue(optimizer.inv_update_ops)
sync_optimizer = tf.train.SyncReplicasOptimizer(
opt=optimizer,
replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks))
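The OpQueue above spreads the inverse updates across worker steps; a minimal sketch of the loop that drains it, assuming a tf.Session `sess` plus `train_op` and `num_steps` from the surrounding example:

    for _ in range(num_steps):
      # next_op() returns one inverse-update op per call, in rotation.
      sess.run(inv_update_queue.next_op(sess))
      sess.run(train_op)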
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
index 4275ceadc2..0f0dbb53f4 100644
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ b/tensorflow/contrib/kfac/examples/mlp.py
@@ -239,3 +239,85 @@ def train_mnist_multitower(data_dir,
})
return minimize(
loss, accuracy, layer_collection, session_config=session_config)
+
+
+def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
+ """Train an MLP on MNIST using tf.estimator.
+
+ Args:
+ data_dir: string. Directory to read MNIST examples from.
+ num_epochs: int. Number of passes to make over the training set.
+ use_fake_data: bool. If True, generate a synthetic dataset.
+
+  Returns:
+    None. Training progress (loss and accuracy) is reported by a
+    LoggingTensorHook rather than returned.
+ """
+
+ # Load a dataset.
+ def input_fn():
+ tf.logging.info("Loading MNIST into memory.")
+ return mnist.load_mnist(
+ data_dir,
+ num_epochs=num_epochs,
+ batch_size=64,
+ flatten_images=True,
+ use_fake_data=use_fake_data)
+
+ def model_fn(features, labels, mode, params):
+ """Model function for MLP trained with K-FAC.
+
+ Args:
+ features: Tensor of shape [batch_size, input_size]. Input features.
+ labels: Tensor of shape [batch_size]. Target labels for training.
+ mode: tf.estimator.ModeKey. Must be TRAIN.
+ params: ignored.
+
+ Returns:
+ EstimatorSpec for training.
+
+ Raises:
+ ValueError: If 'mode' is anything other than TRAIN.
+ """
+ del params
+
+ if mode != tf.estimator.ModeKeys.TRAIN:
+      raise ValueError("Only training is supported with this API.")
+
+    # Build an MLP.
+ layer_collection = lc.LayerCollection()
+ loss, accuracy = build_model(
+ features, labels, num_labels=10, layer_collection=layer_collection)
+
+ # Train with K-FAC.
+ global_step = tf.train.get_or_create_global_step()
+ optimizer = opt.KfacOptimizer(
+ learning_rate=tf.train.exponential_decay(
+ 0.00002, global_step, 10000, 0.5, staircase=True),
+ cov_ema_decay=0.95,
+ damping=0.0001,
+ layer_collection=layer_collection,
+ momentum=0.99)
+
+    # Run cov_update_op every step. Run one inv_update_op per step.
+ cov_update_op = optimizer.cov_update_op
+ inv_update_op = tf.group(
+ tf.contrib.kfac.utils.batch_execute(
+ global_step, optimizer.inv_update_thunks, batch_size=1))
+ with tf.control_dependencies([cov_update_op, inv_update_op]):
+ train_op = optimizer.minimize(loss, global_step=global_step)
+
+ # Print metrics every 5 sec.
+ hooks = [
+ tf.train.LoggingTensorHook(
+ {
+ "loss": loss,
+ "accuracy": accuracy
+ }, every_n_secs=5),
+ ]
+ return tf.estimator.EstimatorSpec(
+ mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
+
+  # Train with Estimator until input_fn() is exhausted. Using the Estimator
+  # API is a prerequisite for TPU compatibility.
+ estimator = tf.estimator.Estimator(model_fn=model_fn)
+ estimator.train(input_fn=input_fn)
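The new function can be exercised standalone on synthetic data, mirroring the test added below:

    # One epoch on fake data; loss/accuracy are printed by the LoggingTensorHook.
    mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)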
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
index b318c71a56..9c34ade1d2 100644
--- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
+++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
@@ -33,7 +33,11 @@ FLAGS = None
def main(argv):
_ = argv
- if FLAGS.num_towers > 1:
+ if FLAGS.use_estimator:
+ if FLAGS.num_towers != 1:
+ raise ValueError("Only 1 device supported in tf.estimator example.")
+ mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
+ elif FLAGS.num_towers > 1:
mlp.train_mnist_multitower(
FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
else:
@@ -52,5 +56,9 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of CPUs to split minibatch across.")
+ parser.add_argument(
+ "--use_estimator",
+ action="store_true",
+ help="Use tf.estimator API to train.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
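With the new flag the estimator code path is chosen from the command line, e.g.:

    python mlp_mnist_main.py --use_estimator

Omitting the flag preserves the existing single-tower and multi-tower paths.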
diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
index 34a942d27f..22da6c29f1 100644
--- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py
+++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
@@ -53,6 +53,11 @@ class MlpTest(tf.test.TestCase):
mlp.train_mnist_multitower(
data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
+ def testTrainMnistEstimator(self):
+ with tf.Graph().as_default():
+ # Ensure model training doesn't crash.
+ mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
+
if __name__ == "__main__":
tf.test.main()
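The new test can be run like the existing ones; assuming the usual Bazel target layout for these examples (the exact target name is an assumption, not part of this commit):

    bazel test //tensorflow/contrib/kfac/examples/tests:mlp_test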
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index d66395ded7..a7b1f9d35c 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -281,11 +281,6 @@ class FisherEstimator(object):
return thunk
- @property
- def inv_updates_dict(self):
- """Returns a dictionary mapping strings to inv_update_ops."""
- return {op.name: op for op in self.inv_update_ops}
-
def _get_grads_lists_gradients(self, tensors):
grads_flat = gradients_impl.gradients(
self._layers.total_sampled_loss(),
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 0c9444241f..1974b07acf 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -137,13 +137,33 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
self._losses = layer_collection.losses
- self.cov_update_op = self._fisher_est.cov_update_op
- self.inv_update_op = self._fisher_est.inv_update_op
- self.inv_updates_dict = self._fisher_est.inv_updates_dict
-
super(KfacOptimizer, self).__init__(learning_rate, name=name)
@property
+ def cov_update_thunks(self):
+ return self._fisher_est.cov_update_thunks
+
+ @property
+ def cov_update_ops(self):
+ return self._fisher_est.cov_update_ops
+
+ @property
+ def cov_update_op(self):
+ return self._fisher_est.cov_update_op
+
+ @property
+ def inv_update_thunks(self):
+ return self._fisher_est.inv_update_thunks
+
+ @property
+ def inv_update_ops(self):
+ return self._fisher_est.inv_update_ops
+
+ @property
+ def inv_update_op(self):
+ return self._fisher_est.inv_update_op
+
+ @property
def variables(self):
return self._fisher_est.variables
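The forwarded properties let callers schedule the two update families themselves rather than rely on a single fused op. A minimal session-driven sketch, assuming a constructed KfacOptimizer `optimizer` plus `sess`, `train_op`, and `num_steps` from the caller:

    for step in range(num_steps):
      # Refresh covariance statistics every step (relatively cheap).
      sess.run(optimizer.cov_update_op)
      # Recompute the more expensive inverses only every 10 steps.
      if step % 10 == 0:
        sess.run(optimizer.inv_update_op)
      sess.run(train_op)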