aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-01 15:32:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:40:18 -0800
commitd719238f1ddedf5569bfd0ca13fa3a29bfecdd78 (patch)
tree92bb97c73f5ec3eb34208c4ea5d8155738b87916
parenta7398af84e30eda2cf47496c82bdfe1c9e36381d (diff)
Add iterate_batches arg to Estimator.predict
PiperOrigin-RevId: 184205196
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 8d59fe66d9..63d0f1e1d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -600,7 +600,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
input_fn=None,
batch_size=None,
outputs=None,
- as_iterable=True):
+ as_iterable=True,
+ iterate_batches=False):
"""Returns predictions for given features.
Args:
@@ -616,6 +617,9 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
for each example until inputs are exhausted. Note: The inputs must
terminate if you want the iterable to terminate (e.g. be sure to pass
num_epochs=1 if you are using something like read_batch_features).
+ iterate_batches: If True, yield the whole batch at once instead of
+ decomposing the batch into individual samples. Only relevant when
+ as_iterable is True.
Returns:
A numpy array of predicted classes or regression values if the
@@ -635,7 +639,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
input_fn=input_fn,
feed_fn=feed_fn,
outputs=outputs,
- as_iterable=as_iterable)
+ as_iterable=as_iterable,
+ iterate_batches=iterate_batches)
def get_variable_value(self, name):
"""Returns value of the variable given by name.