aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
blob: 9f8cea375b7f7484eb5ba446f9e715d86bcf810b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API wrapper for using certain Keras models with the Scikit-Learn API.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import types

import numpy as np

from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
from tensorflow.python.util import tf_inspect


class BaseWrapper(object):
  """Base class for the Keras scikit-learn wrapper.

  Warning: This class should not be used directly.
  Use descendant classes instead.

  Arguments:
      build_fn: callable function or class instance
      **sk_params: model parameters & fitting parameters

  The build_fn should construct, compile and return a Keras model, which
  will then be used to fit/predict. One of the following
  three values could be passed to build_fn:
  1. A function
  2. An instance of a class that implements the __call__ method
  3. None. This means you implement a class that inherits from either
  `KerasClassifier` or `KerasRegressor`. The __call__ method of the
  present class will then be treated as the default build_fn.

  `sk_params` takes both model parameters and fitting parameters. Legal model
  parameters are the arguments of `build_fn`. Note that like all other
  estimators in scikit-learn, `build_fn` should provide default values for
  its arguments, so that you could create the estimator without passing any
  values to `sk_params`.

  `sk_params` could also accept parameters for calling `fit`, `predict`,
  `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
  Fitting (predicting) parameters are selected in the following order:

  1. Values passed to the dictionary arguments of
  `fit`, `predict`, `predict_proba`, and `score` methods
  2. Values passed to `sk_params`
  3. The default values of the `keras.models.Sequential`
  `fit`, `predict`, `predict_proba` and `score` methods

  The legacy name `nb_epoch` is accepted in `sk_params` as an alias for
  `epochs`; an explicit `epochs` takes precedence over it.

  When using scikit-learn's `grid_search` API, legal tunable parameters are
  those you could pass to `sk_params`, including fitting parameters.
  In other words, you could use `grid_search` to search for the best
  `batch_size` or `epochs` as well as the model parameters.
  """

  def __init__(self, build_fn=None, **sk_params):
    self.build_fn = build_fn
    self.sk_params = sk_params
    self.check_params(sk_params)

  def _build_fn_arg_source(self):
    """Returns the callable whose signature lists the model-building args.

    Shared by `check_params` and `fit` so both agree on which arguments
    the model-building callable accepts.

    Returns:
        `self.__call__` when `build_fn` is None; `build_fn.__call__` when
        `build_fn` is neither a plain function nor a method (i.e. a callable
        object); otherwise `build_fn` itself.
    """
    if self.build_fn is None:
      return self.__call__
    if not isinstance(self.build_fn, (types.FunctionType, types.MethodType)):
      return self.build_fn.__call__
    return self.build_fn

  def check_params(self, params):
    """Checks for user typos in "params".

    Arguments:
        params: dictionary; the parameters to be checked

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate, self._build_fn_arg_source()
    ]

    legal_params = set()
    for fn in legal_params_fns:
      legal_params.update(tf_inspect.getargspec(fn)[0])

    for params_name in params:
      # `nb_epoch` is the pre-Keras-2 name of `epochs`; it is accepted here
      # and translated to `epochs` in `fit`.
      if params_name not in legal_params and params_name != 'nb_epoch':
        raise ValueError('{} is not a legal parameter'.format(params_name))

  def get_params(self, **params):  # pylint: disable=unused-argument
    """Gets parameters for this estimator.

    Arguments:
        **params: ignored (exists for API compatibility).

    Returns:
        Dictionary of parameter names mapped to their values. `sk_params`
        values are deep-copied; `build_fn` is returned as-is.
    """
    res = copy.deepcopy(self.sk_params)
    res.update({'build_fn': self.build_fn})
    return res

  def set_params(self, **params):
    """Sets the parameters of this estimator.

    Arguments:
        **params: Dictionary of parameter names mapped to their values.

    Returns:
        self

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    self.check_params(params)
    self.sk_params.update(params)
    return self

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Arguments:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for X.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.
    """
    build_fn = self.__call__ if self.build_fn is None else self.build_fn
    build_args = self.filter_sk_params(self._build_fn_arg_source())
    self.model = build_fn(**build_args)

    # The loss may be stored as a string or as a function object.
    loss_name = self.model.loss
    if hasattr(loss_name, '__name__'):
      loss_name = loss_name.__name__
    # Non-2-D labels are one-hot encoded for categorical crossentropy.
    if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
      y = to_categorical(y)

    fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    # `nb_epoch` passes `check_params` but is not an argument of
    # `Sequential.fit`, so `filter_sk_params` drops it. Honor it as an alias
    # of `epochs` unless `epochs` was given explicitly.
    if 'nb_epoch' in self.sk_params and 'epochs' not in fit_args:
      fit_args['epochs'] = self.sk_params['nb_epoch']
    fit_args.update(kwargs)

    return self.model.fit(x, y, **fit_args)

  def filter_sk_params(self, fn, override=None):
    """Filters `sk_params` and returns those in `fn`'s arguments.

    Arguments:
        fn : arbitrary function
        override: dictionary, values to override sk_params

    Returns:
        res : dictionary containing the variables that appear both
            in sk_params and in fn's arguments, updated with `override`.
    """
    fn_args = tf_inspect.getargspec(fn)[0]
    res = {
        name: value
        for name, value in self.sk_params.items()
        if name in fn_args
    }
    res.update(override or {})
    return res


class KerasClassifier(BaseWrapper):
  """Implementation of the scikit-learn classifier API for Keras.
  """

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Arguments:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for X. A column vector of shape `(n_samples, 1)` is
            treated like a 1-D label vector.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.

    Raises:
        ValueError: In case of invalid shape for `y` argument.
    """
    y = np.array(y)
    if len(y.shape) == 2 and y.shape[1] > 1:
      # Multi-column (e.g. one-hot) targets: the classes are column indices.
      self.classes_ = np.arange(y.shape[1])
    elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
      # Flatten a column vector so BaseWrapper.fit sees 1-D labels and
      # one-hot encodes them for categorical crossentropy; an `(n, 1)`
      # array would otherwise be passed through unencoded.
      y = y.ravel()
      self.classes_ = np.unique(y)
      # Map arbitrary labels to contiguous indices into `classes_`.
      y = np.searchsorted(self.classes_, y)
    else:
      raise ValueError('Invalid shape for y: ' + str(y.shape))
    self.n_classes_ = len(self.classes_)
    return super(KerasClassifier, self).fit(x, y, **kwargs)

  def predict(self, x, **kwargs):
    """Returns the class predictions for the given test data.

    Arguments:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_classes`.

    Returns:
        preds: array-like, shape `(n_samples,)`
            Class predictions, expressed as the original labels seen in
            `fit` (looked up in `classes_`).
    """
    kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
    classes = self.model.predict_classes(x, **kwargs)
    return self.classes_[classes]

  def predict_proba(self, x, **kwargs):
    """Returns class probability estimates for the given test data.

    Arguments:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_proba`.

    Returns:
        proba: array-like, shape `(n_samples, n_outputs)`
            Class probability estimates.
            In the case of binary classification,
            to match the scikit-learn API,
            will return an array of shape `(n_samples, 2)`
            (instead of `(n_samples, 1)` as in Keras).
    """
    kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
    probs = self.model.predict_proba(x, **kwargs)

    # A single output column means binary classification.
    if probs.shape[1] == 1:
      # First column is the probability of class 0, second of class 1.
      probs = np.hstack([1 - probs, probs])
    return probs

  def score(self, x, y, **kwargs):
    """Returns the mean accuracy on the given test data and labels.

    Arguments:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for x. A column vector of shape `(n_samples, 1)`
            is treated like a 1-D label vector.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Mean accuracy of predictions on X wrt. y.

    Raises:
        ValueError: If the underlying model isn't configured to
            compute accuracy (no `'acc'` or `'accuracy'` metric). You
            should pass `metrics=["accuracy"]` to the `.compile()` method
            of the model.
    """
    y = np.asarray(y)
    # Flatten a column vector so it is one-hot encoded below, as in `fit`.
    if len(y.shape) == 2 and y.shape[1] == 1:
      y = y.ravel()
    y = np.searchsorted(self.classes_, y)
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

    loss_name = self.model.loss
    if hasattr(loss_name, '__name__'):
      loss_name = loss_name.__name__
    if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
      # Pass the class count learned in `fit` explicitly. Letting
      # `to_categorical` infer it from the test labels can give too few
      # columns when those labels lack the highest class.
      y = to_categorical(y, len(self.classes_))

    outputs = self.model.evaluate(x, y, **kwargs)
    if not isinstance(outputs, list):
      outputs = [outputs]
    for name, output in zip(self.model.metrics_names, outputs):
      # The accuracy metric is reported as 'acc' or 'accuracy' depending on
      # the Keras version.
      if name in ('acc', 'accuracy'):
        return output
    raise ValueError('The model is not configured to compute accuracy. '
                     'You should pass `metrics=["accuracy"]` to '
                     'the `model.compile()` method.')


class KerasRegressor(BaseWrapper):
  """Implementation of the scikit-learn regressor API for Keras.
  """

  def predict(self, x, **kwargs):
    """Returns predictions for the given test data.

    Arguments:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.predict`.

    Returns:
        preds: array-like, shape `(n_samples,)` for single-output models,
            otherwise `(n_samples, n_outputs)`.
            Predictions.
    """
    kwargs = self.filter_sk_params(Sequential.predict, kwargs)
    preds = np.asarray(self.model.predict(x, **kwargs))
    # Drop only a trailing singleton output axis. A bare `np.squeeze` would
    # turn a single-sample prediction into a 0-d array, and would drop the
    # batch axis of a single-sample multi-output prediction.
    if preds.ndim and preds.shape[-1] == 1:
      preds = np.squeeze(preds, axis=-1)
    return preds

  def score(self, x, y, **kwargs):
    """Returns the negated mean loss on the given test data and labels.

    scikit-learn treats larger `score` values as better (e.g. when grid
    searching), so the loss is negated: a lower loss yields a higher score.

    Arguments:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        y: array-like, shape `(n_samples,)`
            True labels for X.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Negated mean loss of predictions on X wrt. y.
    """
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
    loss = self.model.evaluate(x, y, **kwargs)
    # With extra metrics, `evaluate` returns a list whose first entry is
    # the loss.
    if isinstance(loss, list):
      loss = loss[0]
    return -loss