diff options
author | 2016-12-15 13:56:39 -0800 | |
---|---|---|
committer | 2016-12-15 14:04:35 -0800 | |
commit | 12954acc1df5514c8193bb104e044432e34e5e60 (patch) | |
tree | 3fb5346169a33b0d7c28f9a0f70b3c2d93af3d6a | |
parent | c7c709499795f3028fca222c3ce77a1fc34698b7 (diff) |
Update SVM classifier to be consitent with other classifiers with respect to model_dir handling.
Change: 142184770
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/svm.py | 14 |
1 files changed, 5 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index 561a898e78..a7e872b40f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -20,12 +20,9 @@ from __future__ import print_function import inspect import re -import tempfile from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework import list_variables -from tensorflow.contrib.framework import load_variable from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable @@ -149,11 +146,10 @@ class SVM(trainable.Trainable, evaluable.Evaluable): symmetric_l2_regularization=l2_regularization) self._feature_columns = feature_columns - self._model_dir = model_dir or tempfile.mkdtemp() self._chief_hook = linear._SdcaUpdateWeightsHook() # pylint: disable=protected-access self._estimator = estimator.Estimator( model_fn=linear.sdca_model_fn, - model_dir=self._model_dir, + model_dir=model_dir, config=config, params={ "head": head_lib._binary_svm_head( # pylint: disable=protected-access @@ -229,7 +225,7 @@ class SVM(trainable.Trainable, evaluable.Evaluable): # pylint: enable=protected-access def get_variable_names(self): - return [name for name, _ in list_variables(self._model_dir)] + return self._estimator.get_variable_names() def export(self, export_dir, signature_fn=None, input_fn=None, default_batch_size=1, @@ -264,15 +260,15 @@ class SVM(trainable.Trainable, evaluable.Evaluable): def weights_(self): values = {} optimizer_regex = r".*/"+self._optimizer.get_name() + r"(_\d)?$" - for name, _ in list_variables(self._model_dir): + for name in self.get_variable_names(): if (name.startswith("linear/") and name != "linear/bias_weight" and not re.match(optimizer_regex, name)): - values[name] = load_variable(self._model_dir, name) + values[name] = self.get_variable_value(name) if len(values) == 1: return values[list(values.keys())[0]] return values @property def bias_(self): - return load_variable(self._model_dir, name="linear/bias_weight") + return self.get_variable_value("linear/bias_weight") |