aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py5
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py1
2 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
index 0f0dbb53f4..87eed03888 100644
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ b/tensorflow/contrib/kfac/examples/mlp.py
@@ -317,7 +317,10 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
+ run_config = tf.estimator.RunConfig(
+ model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
+
# Train until input_fn() is empty with Estimator. This is a prerequisite for
# TPU compatibility.
- estimator = tf.estimator.Estimator(model_fn=model_fn)
+ estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index cc48e3c69f..fe8e39c212 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
+ "set_global_constants",
"SequenceDict",
"tensors_to_column",
"column_to_tensors",