diff options
author | 2017-06-22 17:27:03 -0700 | |
---|---|---|
committer | 2017-06-22 17:30:42 -0700 | |
commit | e6b08c491063e0aa0485fa5db2c6332af1519b7f (patch) | |
tree | 0d9c7e39d95314e7ebb1c089ad46a0f74800d8a1 | |
parent | baf2bf53afc730108a0669f9310126f55ca45650 (diff) |
Aligned how model-fns handled params among linear/dnn/combined estimators.
PiperOrigin-RevId: 159899925
-rw-r--r-- | tensorflow/python/estimator/canned/linear.py | 73 |
1 file changed, 36 insertions, 37 deletions
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 552b1bdf01..05b1e5b44a 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -42,8 +42,8 @@ def _get_default_optimizer(feature_columns): return ftrl.FtrlOptimizer(learning_rate=learning_rate) -# TODO(b/36813849): Revisit passing params vs named arguments. -def _linear_model_fn(features, labels, mode, params, config): +def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, + partitioner, config): """A model_fn for linear models that use a gradient-based optimizer. Args: @@ -51,13 +51,12 @@ def _linear_model_fn(features, labels, mode, params, config): labels: `Tensor` of shape `[batch_size, logits_dimension]`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. - params: A dict of hyperparameters. - The following hyperparameters are expected: - * head: A `Head` instance. - * feature_columns: An iterable containing all the feature columns used by - the model. - * optimizer: string, `Optimizer` object, or callable that defines the - optimizer to use for training. If `None`, will use a FTRL optimizer. + head: A `Head` instance. + feature_columns: An iterable containing all the feature columns used by + the model. + optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training. If `None`, will use a FTRL optimizer. + partitioner: Partitioner for variables. config: `RunConfig` object to configure the runtime settings. Returns: @@ -66,14 +65,12 @@ def _linear_model_fn(features, labels, mode, params, config): Raises: ValueError: If mode or params are invalid. 
""" - head = params['head'] - feature_columns = tuple(params['feature_columns']) optimizer = optimizers.get_optimizer_instance( - params.get('optimizer') or _get_default_optimizer(feature_columns), + optimizer or _get_default_optimizer(feature_columns), learning_rate=_LEARNING_RATE) num_ps_replicas = config.num_ps_replicas if config else 0 - partitioner = params.get('partitioner') or ( + partitioner = partitioner or ( partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) @@ -210,16 +207,20 @@ class LinearClassifier(estimator.Estimator): head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary) + def _model_fn(features, labels, mode, config): + return _linear_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + partitioner=partitioner, + config=config) super(LinearClassifier, self).__init__( - model_fn=_linear_model_fn, + model_fn=_model_fn, model_dir=model_dir, - config=config, - params={ - 'head': head, - 'feature_columns': feature_columns, - 'optimizer': optimizer, - 'partitioner': partitioner, - }) + config=config) class LinearRegressor(estimator.Estimator): @@ -298,21 +299,19 @@ class LinearRegressor(estimator.Estimator): config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. 
""" + head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + label_dimension=label_dimension, weight_column=weight_column) + def _model_fn(features, labels, mode, config): + return _linear_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + partitioner=partitioner, + config=config) super(LinearRegressor, self).__init__( - model_fn=_linear_model_fn, + model_fn=_model_fn, model_dir=model_dir, - config=config, - params={ - # pylint: disable=protected-access - 'head': - head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=label_dimension, - weight_column=weight_column), - # pylint: enable=protected-access - 'feature_columns': - feature_columns, - 'optimizer': - optimizer, - 'partitioner': - partitioner, - }) + config=config) |