aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-15 13:56:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 14:04:35 -0800
commit12954acc1df5514c8193bb104e044432e34e5e60 (patch)
tree3fb5346169a33b0d7c28f9a0f70b3c2d93af3d6a
parentc7c709499795f3028fca222c3ce77a1fc34698b7 (diff)
Update SVM classifier to be consitent with other classifiers with respect to model_dir handling.
Change: 142184770
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/svm.py14
1 files changed, 5 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index 561a898e78..a7e872b40f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -20,12 +20,9 @@ from __future__ import print_function
import inspect
import re
-import tempfile
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated_arg_values
-from tensorflow.contrib.framework import list_variables
-from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
@@ -149,11 +146,10 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
symmetric_l2_regularization=l2_regularization)
self._feature_columns = feature_columns
- self._model_dir = model_dir or tempfile.mkdtemp()
self._chief_hook = linear._SdcaUpdateWeightsHook() # pylint: disable=protected-access
self._estimator = estimator.Estimator(
model_fn=linear.sdca_model_fn,
- model_dir=self._model_dir,
+ model_dir=model_dir,
config=config,
params={
"head": head_lib._binary_svm_head( # pylint: disable=protected-access
@@ -229,7 +225,7 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
# pylint: enable=protected-access
def get_variable_names(self):
- return [name for name, _ in list_variables(self._model_dir)]
+ return self._estimator.get_variable_names()
def export(self, export_dir, signature_fn=None,
input_fn=None, default_batch_size=1,
@@ -264,15 +260,15 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
def weights_(self):
values = {}
optimizer_regex = r".*/"+self._optimizer.get_name() + r"(_\d)?$"
- for name, _ in list_variables(self._model_dir):
+ for name in self.get_variable_names():
if (name.startswith("linear/") and
name != "linear/bias_weight" and
not re.match(optimizer_regex, name)):
- values[name] = load_variable(self._model_dir, name)
+ values[name] = self.get_variable_value(name)
if len(values) == 1:
return values[list(values.keys())[0]]
return values
@property
def bias_(self):
- return load_variable(self._model_dir, name="linear/bias_weight")
+ return self.get_variable_value("linear/bias_weight")