aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned/linear.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/canned/linear.py')
-rw-r--r--tensorflow/python/estimator/canned/linear.py73
1 files changed, 37 insertions, 36 deletions
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 05b1e5b44a..552b1bdf01 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -42,8 +42,8 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
- partitioner, config):
+# TODO(b/36813849): Revisit passing params vs named arguments.
+def _linear_model_fn(features, labels, mode, params, config):
"""A model_fn for linear models that use a gradient-based optimizer.
Args:
@@ -51,12 +51,13 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
labels: `Tensor` of shape `[batch_size, logits_dimension]`.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
- head: A `Head` instance.
- feature_columns: An iterable containing all the feature columns used by
- the model.
- optimizer: string, `Optimizer` object, or callable that defines the
- optimizer to use for training. If `None`, will use a FTRL optimizer.
- partitioner: Partitioner for variables.
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training. If `None`, will use a FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
Returns:
@@ -65,12 +66,14 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
Raises:
ValueError: If mode or params are invalid.
"""
+ head = params['head']
+ feature_columns = tuple(params['feature_columns'])
optimizer = optimizers.get_optimizer_instance(
- optimizer or _get_default_optimizer(feature_columns),
+ params.get('optimizer') or _get_default_optimizer(feature_columns),
learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
- partitioner = partitioner or (
+ partitioner = params.get('partitioner') or (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
@@ -207,20 +210,16 @@ class LinearClassifier(estimator.Estimator):
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_column=weight_column,
label_vocabulary=label_vocabulary)
- def _model_fn(features, labels, mode, config):
- return _linear_model_fn(
- features=features,
- labels=labels,
- mode=mode,
- head=head,
- feature_columns=tuple(feature_columns or []),
- optimizer=optimizer,
- partitioner=partitioner,
- config=config)
super(LinearClassifier, self).__init__(
- model_fn=_model_fn,
+ model_fn=_linear_model_fn,
model_dir=model_dir,
- config=config)
+ config=config,
+ params={
+ 'head': head,
+ 'feature_columns': feature_columns,
+ 'optimizer': optimizer,
+ 'partitioner': partitioner,
+ })
class LinearRegressor(estimator.Estimator):
@@ -299,19 +298,21 @@ class LinearRegressor(estimator.Estimator):
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
"""
- head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
- label_dimension=label_dimension, weight_column=weight_column)
- def _model_fn(features, labels, mode, config):
- return _linear_model_fn(
- features=features,
- labels=labels,
- mode=mode,
- head=head,
- feature_columns=tuple(feature_columns or []),
- optimizer=optimizer,
- partitioner=partitioner,
- config=config)
super(LinearRegressor, self).__init__(
- model_fn=_model_fn,
+ model_fn=_linear_model_fn,
model_dir=model_dir,
- config=config)
+ config=config,
+ params={
+ # pylint: disable=protected-access
+ 'head':
+ head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=label_dimension,
+ weight_column=weight_column),
+ # pylint: enable=protected-access
+ 'feature_columns':
+ feature_columns,
+ 'optimizer':
+ optimizer,
+ 'partitioner':
+ partitioner,
+ })