aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/BUILD19
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/__init__.py2
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD32
-rw-r--r--tensorflow/contrib/tensor_forest/client/__init__.py1
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py (renamed from tensorflow/contrib/learn/python/learn/estimators/random_forest.py)1
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest_test.py (renamed from tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py)2
-rw-r--r--tensorflow/examples/learn/random_forest_mnist.py11
7 files changed, 39 insertions, 29 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index c7b2c2d427..81d1a1f9db 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -32,10 +32,6 @@ py_library(
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/session_bundle:gc",
- "//tensorflow/contrib/tensor_forest:client_lib",
- "//tensorflow/contrib/tensor_forest:data_ops_py",
- "//tensorflow/contrib/tensor_forest:eval_metrics",
- "//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -675,21 +671,6 @@ py_test(
)
py_test(
- name = "random_forest_test",
- size = "medium",
- srcs = ["python/learn/estimators/random_forest_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":learn",
- "//tensorflow/contrib/learn/python/learn/datasets",
- "//tensorflow/contrib/tensor_forest:tensor_forest_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
name = "dynamic_rnn_estimator_test",
size = "medium",
srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index be5a3d126a..11c196b702 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -322,8 +322,6 @@ from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
-from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
-from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
from tensorflow.contrib.learn.python.learn.estimators.run_config import ClusterConfig
from tensorflow.contrib.learn.python.learn.estimators.run_config import Environment
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index c705d80ded..a61d61c7e2 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -121,6 +121,7 @@ py_library(
":constants",
":data_ops_py",
":eval_metrics",
+ ":random_forest",
":tensor_forest_ops_py",
":tensor_forest_py",
],
@@ -395,3 +396,34 @@ py_test(
"//tensorflow/python:variables",
],
)
+
+py_library(
+ name = "random_forest",
+ srcs = ["client/random_forest.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_lib",
+ ":data_ops_py",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:state_ops",
+ ],
+)
+
+py_test(
+ name = "random_forest_test",
+ size = "medium",
+ srcs = ["client/random_forest_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_forest",
+ ":tensor_forest_py",
+ "//tensorflow/contrib/learn/python/learn/datasets",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/tensor_forest/client/__init__.py b/tensorflow/contrib/tensor_forest/client/__init__.py
index 1a0c87c4cc..335c1e4c43 100644
--- a/tensorflow/contrib/tensor_forest/client/__init__.py
+++ b/tensorflow/contrib/tensor_forest/client/__init__.py
@@ -19,4 +19,5 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.tensor_forest.client import eval_metrics
+from tensorflow.contrib.tensor_forest.client import random_forest
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index c83109d5fe..b711fb22ff 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
index d817116329..1e774dab2b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py
@@ -28,7 +28,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
-from tensorflow.contrib.learn.python.learn.estimators import random_forest
+from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.platform import test
diff --git a/tensorflow/examples/learn/random_forest_mnist.py b/tensorflow/examples/learn/random_forest_mnist.py
index a34d52275a..6a943eb42e 100644
--- a/tensorflow/examples/learn/random_forest_mnist.py
+++ b/tensorflow/examples/learn/random_forest_mnist.py
@@ -21,25 +21,24 @@ import argparse
import sys
import tempfile
-import tensorflow as tf
-
# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
-from tensorflow.contrib.learn.python.learn.estimators\
- import random_forest
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
+from tensorflow.contrib.tensor_forest.client\
+ import random_forest
from tensorflow.contrib.tensor_forest.python\
import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.python.platform import app
FLAGS = None
def build_estimator(model_dir):
"""Build an estimator."""
- params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
+ params = tensor_forest.ForestHParams(
num_classes=10, num_features=784,
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
@@ -129,4 +128,4 @@ if __name__ == '__main__':
help='If true, use training loss as termination criteria.'
)
FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)