aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-22 17:27:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-22 17:30:42 -0700
commite6b08c491063e0aa0485fa5db2c6332af1519b7f (patch)
tree0d9c7e39d95314e7ebb1c089ad46a0f74800d8a1
parentbaf2bf53afc730108a0669f9310126f55ca45650 (diff)
Alligned how model-fns handled params among linear/dnn/combined estimators.
PiperOrigin-RevId: 159899925
-rw-r--r--tensorflow/python/estimator/canned/linear.py73
1 files changed, 36 insertions, 37 deletions
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 552b1bdf01..05b1e5b44a 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -42,8 +42,8 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-# TODO(b/36813849): Revisit passing params vs named arguments.
-def _linear_model_fn(features, labels, mode, params, config):
+def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
+ partitioner, config):
"""A model_fn for linear models that use a gradient-based optimizer.
Args:
@@ -51,13 +51,12 @@ def _linear_model_fn(features, labels, mode, params, config):
labels: `Tensor` of shape `[batch_size, logits_dimension]`.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
- params: A dict of hyperparameters.
- The following hyperparameters are expected:
- * head: A `Head` instance.
- * feature_columns: An iterable containing all the feature columns used by
- the model.
- * optimizer: string, `Optimizer` object, or callable that defines the
- optimizer to use for training. If `None`, will use a FTRL optimizer.
+ head: A `Head` instance.
+ feature_columns: An iterable containing all the feature columns used by
+ the model.
+ optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training. If `None`, will use a FTRL optimizer.
+ partitioner: Partitioner for variables.
config: `RunConfig` object to configure the runtime settings.
Returns:
@@ -66,14 +65,12 @@ def _linear_model_fn(features, labels, mode, params, config):
Raises:
ValueError: If mode or params are invalid.
"""
- head = params['head']
- feature_columns = tuple(params['feature_columns'])
optimizer = optimizers.get_optimizer_instance(
- params.get('optimizer') or _get_default_optimizer(feature_columns),
+ optimizer or _get_default_optimizer(feature_columns),
learning_rate=_LEARNING_RATE)
num_ps_replicas = config.num_ps_replicas if config else 0
- partitioner = params.get('partitioner') or (
+ partitioner = partitioner or (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
@@ -210,16 +207,20 @@ class LinearClassifier(estimator.Estimator):
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_column=weight_column,
label_vocabulary=label_vocabulary)
+ def _model_fn(features, labels, mode, config):
+ return _linear_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ feature_columns=tuple(feature_columns or []),
+ optimizer=optimizer,
+ partitioner=partitioner,
+ config=config)
super(LinearClassifier, self).__init__(
- model_fn=_linear_model_fn,
+ model_fn=_model_fn,
model_dir=model_dir,
- config=config,
- params={
- 'head': head,
- 'feature_columns': feature_columns,
- 'optimizer': optimizer,
- 'partitioner': partitioner,
- })
+ config=config)
class LinearRegressor(estimator.Estimator):
@@ -298,21 +299,19 @@ class LinearRegressor(estimator.Estimator):
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
"""
+ head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
+ label_dimension=label_dimension, weight_column=weight_column)
+ def _model_fn(features, labels, mode, config):
+ return _linear_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ feature_columns=tuple(feature_columns or []),
+ optimizer=optimizer,
+ partitioner=partitioner,
+ config=config)
super(LinearRegressor, self).__init__(
- model_fn=_linear_model_fn,
+ model_fn=_model_fn,
model_dir=model_dir,
- config=config,
- params={
- # pylint: disable=protected-access
- 'head':
- head_lib._regression_head_with_mean_squared_error_loss(
- label_dimension=label_dimension,
- weight_column=weight_column),
- # pylint: enable=protected-access
- 'feature_columns':
- feature_columns,
- 'optimizer':
- optimizer,
- 'partitioner':
- partitioner,
- })
+ config=config)