diff options
Diffstat (limited to 'tensorflow/python/estimator/canned/linear.py')
-rw-r--r-- | tensorflow/python/estimator/canned/linear.py | 73 |
1 files changed, 37 insertions, 36 deletions
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 05b1e5b44a..552b1bdf01 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -42,8 +42,8 @@ def _get_default_optimizer(feature_columns): return ftrl.FtrlOptimizer(learning_rate=learning_rate) -def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, - partitioner, config): +# TODO(b/36813849): Revisit passing params vs named arguments. +def _linear_model_fn(features, labels, mode, params, config): """A model_fn for linear models that use a gradient-based optimizer. Args: @@ -51,12 +51,13 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, labels: `Tensor` of shape `[batch_size, logits_dimension]`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. - head: A `Head` instance. - feature_columns: An iterable containing all the feature columns used by - the model. - optimizer: string, `Optimizer` object, or callable that defines the - optimizer to use for training. If `None`, will use a FTRL optimizer. - partitioner: Partitioner for variables. + params: A dict of hyperparameters. + The following hyperparameters are expected: + * head: A `Head` instance. + * feature_columns: An iterable containing all the feature columns used by + the model. + * optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training. If `None`, will use a FTRL optimizer. config: `RunConfig` object to configure the runtime settings. Returns: @@ -65,12 +66,14 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, Raises: ValueError: If mode or params are invalid. """ + head = params['head'] + feature_columns = tuple(params['feature_columns']) optimizer = optimizers.get_optimizer_instance( - optimizer or _get_default_optimizer(feature_columns), + params.get('optimizer') or _get_default_optimizer(feature_columns), learning_rate=_LEARNING_RATE) num_ps_replicas = config.num_ps_replicas if config else 0 - partitioner = partitioner or ( + partitioner = params.get('partitioner') or ( partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) @@ -207,20 +210,16 @@ class LinearClassifier(estimator.Estimator): head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary) - def _model_fn(features, labels, mode, config): - return _linear_model_fn( - features=features, - labels=labels, - mode=mode, - head=head, - feature_columns=tuple(feature_columns or []), - optimizer=optimizer, - partitioner=partitioner, - config=config) super(LinearClassifier, self).__init__( - model_fn=_model_fn, + model_fn=_linear_model_fn, model_dir=model_dir, - config=config) + config=config, + params={ + 'head': head, + 'feature_columns': feature_columns, + 'optimizer': optimizer, + 'partitioner': partitioner, + }) class LinearRegressor(estimator.Estimator): @@ -299,19 +298,21 @@ class LinearRegressor(estimator.Estimator): config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. """ - head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access - label_dimension=label_dimension, weight_column=weight_column) - def _model_fn(features, labels, mode, config): - return _linear_model_fn( - features=features, - labels=labels, - mode=mode, - head=head, - feature_columns=tuple(feature_columns or []), - optimizer=optimizer, - partitioner=partitioner, - config=config) super(LinearRegressor, self).__init__( - model_fn=_model_fn, + model_fn=_linear_model_fn, model_dir=model_dir, - config=config) + config=config, + params={ + # pylint: disable=protected-access + 'head': + head_lib._regression_head_with_mean_squared_error_loss( + label_dimension=label_dimension, + weight_column=weight_column), + # pylint: enable=protected-access + 'feature_columns': + feature_columns, + 'optimizer': + optimizer, + 'partitioner': + partitioner, + }) |