diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-08 11:15:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-08 16:24:25 -0800 |
commit | 29761a96990ca2188d5563933bac6206e3631852 (patch) | |
tree | 131f7a0300ebe9aa5c181ba381d6b80d94dafbd8 | |
parent | 3e205871f14932f60cbb995789c796843df7b5fd (diff) |
Let the DNNRegressor constructor accept an optional label_dimension argument.
Change: 138540603
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dnn.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/linear.py | 2 |
2 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index e417aa739f..6523b65fc2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -678,7 +678,8 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor): gradient_clip_norm=None, enable_centered_bias=False, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + label_dimension=1): """Initializes a `DNNRegressor` instance. Args: @@ -711,6 +712,7 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor): labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + label_dimension: Dimension of the label for multilabels. Defaults to 1. Returns: A `DNNRegressor` estimator. @@ -726,7 +728,8 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor): gradient_clip_norm=gradient_clip_norm, enable_centered_bias=enable_centered_bias, config=config, - feature_engineering_fn=feature_engineering_fn) + feature_engineering_fn=feature_engineering_fn, + label_dimension=label_dimension) self.feature_columns = feature_columns self.optimizer = optimizer self.activation_fn = activation_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 8d887f20c5..35b66bc317 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -628,7 +628,7 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable): enable_centered_bias: A bool. If True, estimator will learn a centered bias variable for each class. Rest of the model structure learns the residual after centered bias. - label_dimension: dimension of the label for multilabels. + label_dimension: Dimension of the label for multilabels. Defaults to 1. _joint_weights: If True use a single (possibly partitioned) variable to store the weights. It's faster, but requires all feature columns are sparse and have the 'sum' combiner. Incompatible with SDCAOptimizer. |