diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/ops/training_ops.py')
-rw-r--r-- | tensorflow/contrib/tensor_forest/python/ops/training_ops.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py index 8ca2491d60..5cf5e4af90 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py @@ -25,6 +25,12 @@ import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +flags = tf.app.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string('training_library_base_dir', '', + 'Directory to look for inference library file.') + TRAINING_OPS_FILE = '_training_ops.so' _training_ops = None @@ -101,7 +107,8 @@ def Load(): with _ops_lock: global _training_ops if not _training_ops: - data_files_path = tf.resource_loader.get_data_files_path() + data_files_path = os.path.join(FLAGS.training_library_base_dir, + tf.resource_loader.get_data_files_path()) tf.logging.info('data path: %s', data_files_path) _training_ops = tf.load_op_library(os.path.join( data_files_path, TRAINING_OPS_FILE)) |