diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/trainable.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/trainable.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 8a1548738e..2d1d460425 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -33,17 +33,17 @@ class Trainable(object): """Trains a model given training data `x` predictions and `y` labels. Args: - x: Matrix of shape [n_samples, n_features...]. Can be iterator that - returns arrays of features. The training input samples for fitting the - model. If set, `input_fn` must be `None`. - y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be - iterator that returns array of labels. The training label values - (class labels in classification, real numbers in regression). If set, - `input_fn` must be `None`. Note: For classification, label values must + x: Matrix of shape [n_samples, n_features...] or the dictionary of Matrices. + Can be iterator that returns arrays of features or dictionary of arrays of features. + The training input samples for fitting the model. If set, `input_fn` must be `None`. + y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the dictionary of same. + Can be iterator that returns array of labels or dictionary of array of labels. + The training label values (class labels in classification, real numbers in regression). + If set, `input_fn` must be `None`. Note: For classification, label values must be integers representing the class index (i.e. values from 0 to n_classes-1). input_fn: Input function returning a tuple of: - features - Dictionary of string feature name to `Tensor` or `Tensor`. + features - `Tensor` or dictionary of string feature name to `Tensor`. labels - `Tensor` or dictionary of `Tensor` with labels. If input_fn is set, `x`, `y`, and `batch_size` must be `None`. steps: Number of steps for which to train model. If `None`, train forever. @@ -67,4 +67,3 @@ class Trainable(object): `self`, for chaining. """ raise NotImplementedError - |