diff options
author | Martin Wicke <wicke@google.com> | 2017-03-16 00:02:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-16 01:33:46 -0700 |
commit | bed0e5c3292bcc094a2890183cfaec8273541fff (patch) | |
tree | 46114f63f2e97197207837413aa12c0eff388d7d /tensorflow/python/estimator/model_fn_test.py | |
parent | 1ce242420ecb64b196f29d510f51e795c92696ca (diff) |
Expose Estimator and associated utilities in the API.
Change: 150292011
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index 935272bd88..7ad055862b 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -19,7 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.estimator import export_output +from tensorflow.python.estimator import export from tensorflow.python.estimator import model_fn from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -67,7 +67,7 @@ class EstimatorSpecTrainTest(test.TestCase): train_op=control_flow_ops.no_op(), eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)}, export_outputs={ - 'head_name': export_output.ClassificationOutput(classes=classes) + 'head_name': export.ClassificationOutput(classes=classes) }, training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], @@ -217,7 +217,7 @@ class EstimatorSpecEvalTest(test.TestCase): train_op=control_flow_ops.no_op(), eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)}, export_outputs={ - 'head_name': export_output.ClassificationOutput(classes=classes) + 'head_name': export.ClassificationOutput(classes=classes) }, training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], @@ -401,7 +401,7 @@ class EstimatorSpecInferTest(test.TestCase): train_op=control_flow_ops.no_op(), eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)}, export_outputs={ - 'head_name': export_output.ClassificationOutput(classes=classes) + 'head_name': export.ClassificationOutput(classes=classes) }, training_chief_hooks=[_FakeHook()], training_hooks=[_FakeHook()], @@ -446,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase): model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, - export_outputs=export_output.ClassificationOutput(classes=classes)) + export_outputs=export.ClassificationOutput(classes=classes)) def testExportOutputsValueNotExportOutput(self): with ops.Graph().as_default(), self.test_session(): @@ -465,7 +465,7 @@ class EstimatorSpecInferTest(test.TestCase): with ops.Graph().as_default(), self.test_session(): predictions = {'loss': constant_op.constant(1.)} output_1 = constant_op.constant([1.]) - regression_output = export_output.RegressionOutput(value=output_1) + regression_output = export.RegressionOutput(value=output_1) export_outputs = { 'head-1': regression_output, } @@ -488,9 +488,9 @@ class EstimatorSpecInferTest(test.TestCase): output_3 = constant_op.constant(['3']) export_outputs = { signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - export_output.RegressionOutput(value=output_1), - 'head-2': export_output.ClassificationOutput(classes=output_2), - 'head-3': export_output.PredictOutput(outputs={ + export.RegressionOutput(value=output_1), + 'head-2': export.ClassificationOutput(classes=output_2), + 'head-3': export.PredictOutput(outputs={ 'some_output_3': output_3 })} estimator_spec = model_fn.EstimatorSpec( @@ -506,9 +506,9 @@ class EstimatorSpecInferTest(test.TestCase): output_2 = constant_op.constant(['2']) output_3 = constant_op.constant(['3']) export_outputs = { - 'head-1': export_output.RegressionOutput(value=output_1), - 'head-2': export_output.ClassificationOutput(classes=output_2), - 'head-3': export_output.PredictOutput(outputs={ + 'head-1': export.RegressionOutput(value=output_1), + 'head-2': export.ClassificationOutput(classes=output_2), + 'head-3': export.PredictOutput(outputs={ 'some_output_3': output_3 })} with self.assertRaisesRegexp( |