From 6558e37454de83652b5a9c5beb8f9230faecc7be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 30 Jan 2018 11:47:43 -0800 Subject: K-FAC: expose set_global_constants() for tf.contrib.kfac.utils PiperOrigin-RevId: 183867014 --- tensorflow/contrib/kfac/examples/mlp.py | 5 ++++- tensorflow/contrib/kfac/python/ops/utils_lib.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 0f0dbb53f4..87eed03888 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -317,7 +317,10 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): return tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + run_config = tf.estimator.RunConfig( + model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) + # Train until input_fn() is empty with Estimator. This is a prerequisite for # TPU compatibility. - estimator = tf.estimator.Estimator(model_fn=model_fn) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index cc48e3c69f..fe8e39c212 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -24,6 +24,7 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "set_global_constants", "SequenceDict", "tensors_to_column", "column_to_tensors", -- cgit v1.2.3