aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/python/ops/training_ops.py')
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/training_ops.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 8ca2491d60..5cf5e4af90 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -25,6 +25,12 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('training_library_base_dir', '',
+ 'Directory to look for inference library file.')
+
TRAINING_OPS_FILE = '_training_ops.so'
_training_ops = None
@@ -101,7 +107,8 @@ def Load():
with _ops_lock:
global _training_ops
if not _training_ops:
- data_files_path = tf.resource_loader.get_data_files_path()
+ data_files_path = os.path.join(FLAGS.training_library_base_dir,
+ tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_training_ops = tf.load_op_library(os.path.join(
data_files_path, TRAINING_OPS_FILE))