diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-09 11:54:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-09 12:06:09 -0800 |
commit | 7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (patch) | |
tree | ac821517862671c0d77bcbd67d1c911860324f67 | |
parent | c71ac2dce6bc73536956ff50e261c73f993af16c (diff) |
Move TensorForestEstimator to contrib, since that's where most of its code is and it will not be considered a canned estimator in the near future.
Change: 143989623
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 19 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py (renamed from tensorflow/contrib/learn/python/learn/estimators/random_forest.py) | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest_test.py (renamed from tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py) | 2 | ||||
-rw-r--r-- | tensorflow/examples/learn/random_forest_mnist.py | 11 |
7 files changed, 39 insertions, 29 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index c7b2c2d427..81d1a1f9db 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -32,10 +32,6 @@ py_library( "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/contrib/session_bundle:gc", - "//tensorflow/contrib/tensor_forest:client_lib", - "//tensorflow/contrib/tensor_forest:data_ops_py", - "//tensorflow/contrib/tensor_forest:eval_metrics", - "//tensorflow/contrib/tensor_forest:tensor_forest_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -675,21 +671,6 @@ py_test( ) py_test( - name = "random_forest_test", - size = "medium", - srcs = ["python/learn/estimators/random_forest_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":learn", - "//tensorflow/contrib/learn/python/learn/datasets", - "//tensorflow/contrib/tensor_forest:tensor_forest_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//third_party/py/numpy", - ], -) - -py_test( name = "dynamic_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index be5a3d126a..11c196b702 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -322,8 +322,6 @@ from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey -from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator -from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook from tensorflow.contrib.learn.python.learn.estimators.run_config import ClusterConfig from tensorflow.contrib.learn.python.learn.estimators.run_config import Environment from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index c705d80ded..a61d61c7e2 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -121,6 +121,7 @@ py_library( ":constants", ":data_ops_py", ":eval_metrics", + ":random_forest", ":tensor_forest_ops_py", ":tensor_forest_py", ], @@ -395,3 +396,34 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "random_forest", + srcs = ["client/random_forest.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_lib", + ":data_ops_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + ], +) + +py_test( + name = "random_forest_test", + size = "medium", + srcs = ["client/random_forest_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_forest", + ":tensor_forest_py", + "//tensorflow/contrib/learn/python/learn/datasets", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/tensor_forest/client/__init__.py b/tensorflow/contrib/tensor_forest/client/__init__.py index 1a0c87c4cc..335c1e4c43 100644 --- a/tensorflow/contrib/tensor_forest/client/__init__.py +++ b/tensorflow/contrib/tensor_forest/client/__init__.py @@ -19,4 +19,5 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.tensor_forest.client import eval_metrics +from tensorflow.contrib.tensor_forest.client import random_forest # pylint: enable=unused-import diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index c83109d5fe..b711fb22ff 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py index d817116329..1e774dab2b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py @@ -28,7 +28,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base -from tensorflow.contrib.learn.python.learn.estimators import random_forest +from tensorflow.contrib.tensor_forest.client import random_forest from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.platform import test diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py index a34d52275a..6a943eb42e 100644 --- a/tensorflow/examples/learn/random_forest_mnist.py +++ b/tensorflow/examples/learn/random_forest_mnist.py @@ -21,25 +21,24 @@ import argparse import sys import tempfile -import tensorflow as tf - # pylint: disable=g-backslash-continuation from tensorflow.contrib.learn.python.learn\ import metric_spec -from tensorflow.contrib.learn.python.learn.estimators\ - import random_forest from tensorflow.contrib.tensor_forest.client\ import eval_metrics +from tensorflow.contrib.tensor_forest.client\ + import random_forest from tensorflow.contrib.tensor_forest.python\ import tensor_forest from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.platform import app FLAGS = None def build_estimator(model_dir): """Build an estimator.""" - params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + params = tensor_forest.ForestHParams( num_classes=10, num_features=784, num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes) graph_builder_class = tensor_forest.RandomForestGraphs @@ -129,4 +128,4 @@ if __name__ == '__main__': help='If true, use training loss as termination criteria.' ) FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + app.run(main=main, argv=[sys.argv[0]] + unparsed) |