aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn_test.py
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2017-03-16 00:02:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-16 01:33:46 -0700
commitbed0e5c3292bcc094a2890183cfaec8273541fff (patch)
tree46114f63f2e97197207837413aa12c0eff388d7d /tensorflow/python/estimator/model_fn_test.py
parent1ce242420ecb64b196f29d510f51e795c92696ca (diff)
Expose Estimator and associated utilities in the API.
Change: 150292011
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r--tensorflow/python/estimator/model_fn_test.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index 935272bd88..7ad055862b 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -19,7 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.estimator import export_output
+from tensorflow.python.estimator import export
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@@ -67,7 +67,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
- 'head_name': export_output.ClassificationOutput(classes=classes)
+ 'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
@@ -217,7 +217,7 @@ class EstimatorSpecEvalTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
- 'head_name': export_output.ClassificationOutput(classes=classes)
+ 'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
@@ -401,7 +401,7 @@ class EstimatorSpecInferTest(test.TestCase):
train_op=control_flow_ops.no_op(),
eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
export_outputs={
- 'head_name': export_output.ClassificationOutput(classes=classes)
+ 'head_name': export.ClassificationOutput(classes=classes)
},
training_chief_hooks=[_FakeHook()],
training_hooks=[_FakeHook()],
@@ -446,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
- export_outputs=export_output.ClassificationOutput(classes=classes))
+ export_outputs=export.ClassificationOutput(classes=classes))
def testExportOutputsValueNotExportOutput(self):
with ops.Graph().as_default(), self.test_session():
@@ -465,7 +465,7 @@ class EstimatorSpecInferTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
- regression_output = export_output.RegressionOutput(value=output_1)
+ regression_output = export.RegressionOutput(value=output_1)
export_outputs = {
'head-1': regression_output,
}
@@ -488,9 +488,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_3 = constant_op.constant(['3'])
export_outputs = {
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- export_output.RegressionOutput(value=output_1),
- 'head-2': export_output.ClassificationOutput(classes=output_2),
- 'head-3': export_output.PredictOutput(outputs={
+ export.RegressionOutput(value=output_1),
+ 'head-2': export.ClassificationOutput(classes=output_2),
+ 'head-3': export.PredictOutput(outputs={
'some_output_3': output_3
})}
estimator_spec = model_fn.EstimatorSpec(
@@ -506,9 +506,9 @@ class EstimatorSpecInferTest(test.TestCase):
output_2 = constant_op.constant(['2'])
output_3 = constant_op.constant(['3'])
export_outputs = {
- 'head-1': export_output.RegressionOutput(value=output_1),
- 'head-2': export_output.ClassificationOutput(classes=output_2),
- 'head-3': export_output.PredictOutput(outputs={
+ 'head-1': export.RegressionOutput(value=output_1),
+ 'head-2': export.ClassificationOutput(classes=output_2),
+ 'head-3': export.PredictOutput(outputs={
'some_output_3': output_3
})}
with self.assertRaisesRegexp(