aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-08 11:15:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 16:24:25 -0800
commit29761a96990ca2188d5563933bac6206e3631852 (patch)
tree131f7a0300ebe9aa5c181ba381d6b80d94dafbd8
parent3e205871f14932f60cbb995789c796843df7b5fd (diff)
Let the DNNRegressor constructor accept an optional label_dimension argument.
Change: 138540603
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py2
2 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index e417aa739f..6523b65fc2 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -678,7 +678,8 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
gradient_clip_norm=None,
enable_centered_bias=False,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ label_dimension=1):
"""Initializes a `DNNRegressor` instance.
Args:
@@ -711,6 +712,7 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
labels which are the output of `input_fn` and
returns features and labels which will be fed
into the model.
+ label_dimension: Dimension of the label for multilabels. Defaults to 1.
Returns:
A `DNNRegressor` estimator.
@@ -726,7 +728,8 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
gradient_clip_norm=gradient_clip_norm,
enable_centered_bias=enable_centered_bias,
config=config,
- feature_engineering_fn=feature_engineering_fn)
+ feature_engineering_fn=feature_engineering_fn,
+ label_dimension=label_dimension)
self.feature_columns = feature_columns
self.optimizer = optimizer
self.activation_fn = activation_fn
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 8d887f20c5..35b66bc317 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -628,7 +628,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
enable_centered_bias: A bool. If True, estimator will learn a centered
bias variable for each class. Rest of the model structure learns the
residual after centered bias.
- label_dimension: dimension of the label for multilabels.
+ label_dimension: Dimension of the label for multilabels. Defaults to 1.
_joint_weights: If True use a single (possibly partitioned) variable to
store the weights. It's faster, but requires all feature columns are
sparse and have the 'sum' combiner. Incompatible with SDCAOptimizer.