aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-30 11:47:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-30 12:52:30 -0800
commit6558e37454de83652b5a9c5beb8f9230faecc7be (patch)
treee93567c07dbf5f9942877820002648185c47cf91 /tensorflow/contrib/kfac
parent4c30f9676915704f1e0c31fa7a0729e375cb8412 (diff)
K-FAC: expose set_global_constants() for tf.contrib.kfac.utils
PiperOrigin-RevId: 183867014
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py5
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py1
2 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
index 0f0dbb53f4..87eed03888 100644
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ b/tensorflow/contrib/kfac/examples/mlp.py
@@ -317,7 +317,10 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
+ run_config = tf.estimator.RunConfig(
+ model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
+
# Train until input_fn() is empty with Estimator. This is a prerequisite for
# TPU compatibility.
- estimator = tf.estimator.Estimator(model_fn=model_fn)
+ estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index cc48e3c69f..fe8e39c212 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
+ "set_global_constants",
"SequenceDict",
"tensors_to_column",
"column_to_tensors",