aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/skflow/iris_val_based_early_stopping.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-08 09:06:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-08 10:17:45 -0700
commitc00c073f52c2fc7b6672022c75d0b2abb9d9af3a (patch)
treed5aad86afbf6697bcb6eaffc50c1c1f6f48cd0d0 /tensorflow/examples/skflow/iris_val_based_early_stopping.py
parent9a9219be3531d12c804f671e3e236a0d05c01d70 (diff)
Begin removing feature column inference from linear and dnn estimators. Currently, the fit operation of each of them will infer feature columns from the passed in features. But it only works for dense float inputs.
Also, fixed some lint warnings. Change: 126921818
Diffstat (limited to 'tensorflow/examples/skflow/iris_val_based_early_stopping.py')
-rw-r--r--tensorflow/examples/skflow/iris_val_based_early_stopping.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
index 05dfa96a07..70dd8053aa 100644
--- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py
+++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Example of DNNClassifier for Iris plant dataset, with early stopping."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -35,6 +38,7 @@ def main(unused_argv):
# classifier with early stopping on training data
classifier1 = learn.DNNClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/')
classifier1.fit(x=x_train, y=y_train, steps=2000)
score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test))
@@ -42,6 +46,7 @@ def main(unused_argv):
# classifier with early stopping on validation data, save frequently for
# monitor to pick up new checkpoints.
classifier2 = learn.DNNClassifier(
+ feature_columns=learn.infer_real_valued_columns_from_input(x_train),
hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/',
config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor])