diff options
author | 2018-02-01 15:32:20 -0800 | |
---|---|---|
committer | 2018-02-01 17:40:18 -0800 | |
commit | d719238f1ddedf5569bfd0ca13fa3a29bfecdd78 (patch) | |
tree | 92bb97c73f5ec3eb34208c4ea5d8155738b87916 | |
parent | a7398af84e30eda2cf47496c82bdfe1c9e36381d (diff) |
Add iterate_batches arg to Estimator.predict
PiperOrigin-RevId: 184205196
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 8d59fe66d9..63d0f1e1d4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -600,7 +600,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, input_fn=None, batch_size=None, outputs=None, - as_iterable=True): + as_iterable=True, + iterate_batches=False): """Returns predictions for given features. Args: @@ -616,6 +617,9 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, for each example until inputs are exhausted. Note: The inputs must terminate if you want the iterable to terminate (e.g. be sure to pass num_epochs=1 if you are using something like read_batch_features). + iterate_batches: If True, yield the whole batch at once instead of + decomposing the batch into individual samples. Only relevant when + as_iterable is True. Returns: A numpy array of predicted classes or regression values if the @@ -635,7 +639,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, - as_iterable=as_iterable) + as_iterable=as_iterable, + iterate_batches=iterate_batches) def get_variable_value(self, name): """Returns value of the variable given by name. |