aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/trainable.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/trainable.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/trainable.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py
index 8a1548738e..2d1d460425 100644
--- a/tensorflow/contrib/learn/python/learn/trainable.py
+++ b/tensorflow/contrib/learn/python/learn/trainable.py
@@ -33,17 +33,17 @@ class Trainable(object):
"""Trains a model given training data `x` predictions and `y` labels.
Args:
- x: Matrix of shape [n_samples, n_features...]. Can be iterator that
- returns arrays of features. The training input samples for fitting the
- model. If set, `input_fn` must be `None`.
- y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
- iterator that returns array of labels. The training label values
- (class labels in classification, real numbers in regression). If set,
- `input_fn` must be `None`. Note: For classification, label values must
+ x: Matrix of shape [n_samples, n_features...] or the dictionary of Matrices.
+ Can be iterator that returns arrays of features or dictionary of arrays of features.
+ The training input samples for fitting the model. If set, `input_fn` must be `None`.
+ y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the dictionary of same.
+ Can be iterator that returns array of labels or dictionary of array of labels.
+ The training label values (class labels in classification, real numbers in regression).
+ If set, `input_fn` must be `None`. Note: For classification, label values must
be integers representing the class index (i.e. values from 0 to
n_classes-1).
input_fn: Input function returning a tuple of:
- features - Dictionary of string feature name to `Tensor` or `Tensor`.
+ features - `Tensor` or dictionary of string feature name to `Tensor`.
labels - `Tensor` or dictionary of `Tensor` with labels.
If input_fn is set, `x`, `y`, and `batch_size` must be `None`.
steps: Number of steps for which to train model. If `None`, train forever.
@@ -67,4 +67,3 @@ class Trainable(object):
`self`, for chaining.
"""
raise NotImplementedError
-