diff options
author | 2016-07-08 09:06:54 -0800 | |
---|---|---|
committer | 2016-07-08 10:17:45 -0700 | |
commit | c00c073f52c2fc7b6672022c75d0b2abb9d9af3a (patch) | |
tree | d5aad86afbf6697bcb6eaffc50c1c1f6f48cd0d0 /tensorflow/examples/skflow/iris_val_based_early_stopping.py | |
parent | 9a9219be3531d12c804f671e3e236a0d05c01d70 (diff) |
Begin removing feature column inference from linear and dnn estimators. Currently, the fit operation of each of them will infer feature columns from the passed in features. But it only works for dense float inputs.
Also, fixed some lint warnings.
Change: 126921818
Diffstat (limited to 'tensorflow/examples/skflow/iris_val_based_early_stopping.py')
-rw-r--r-- | tensorflow/examples/skflow/iris_val_based_early_stopping.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py index 05dfa96a07..70dd8053aa 100644 --- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py +++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Example of DNNClassifier for Iris plant dataset, with early stopping.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -35,6 +38,7 @@ def main(unused_argv): # classifier with early stopping on training data classifier1 = learn.DNNClassifier( + feature_columns=learn.infer_real_valued_columns_from_input(x_train), hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model/') classifier1.fit(x=x_train, y=y_train, steps=2000) score1 = metrics.accuracy_score(y_test, classifier1.predict(x_test)) @@ -42,6 +46,7 @@ def main(unused_argv): # classifier with early stopping on validation data, save frequently for # monitor to pick up new checkpoints. classifier2 = learn.DNNClassifier( + feature_columns=learn.infer_real_valued_columns_from_input(x_train), hidden_units=[10, 20, 10], n_classes=3, model_dir='/tmp/iris_model_val/', config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1)) classifier2.fit(x=x_train, y=y_train, steps=2000, monitors=[val_monitor]) |