diff options
author | Mustafa Ispir <ispir@google.com> | 2017-02-24 14:51:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-24 15:54:52 -0800 |
commit | 3e368f12824dcb2961e1d28647b85ef84b23e5c4 (patch) | |
tree | 48be80cfffe25cb7b61a32af85e4fb32b12aede1 /tensorflow/python/estimator/model_fn_test.py | |
parent | 54690047e3c7e9673f3bc8bf10d2b66337544dd4 (diff) |
Move estimator.fit into core.
Change: 148507023
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn_test.py | 79 |
1 files changed, 41 insertions, 38 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index 354679a1f7..b80b965196 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -35,11 +35,11 @@ class _FakeHook(session_run_hook.SessionRunHook): class _InvalidHook(object): - """Invalid hook (not a subclass of `SessionRunHook`.""" + """Invalid hook (not a subclass of `SessionRunHook`).""" class _InvalidScaffold(object): - """Invalid scaffold (not a subclass of `Scaffold`.""" + """Invalid scaffold (not a subclass of `Scaffold`).""" class EstimatorSpecTrainTest(test.TestCase): @@ -49,7 +49,7 @@ class EstimatorSpecTrainTest(test.TestCase): """Tests that no errors are raised when all required arguments are set.""" with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), train_op=control_flow_ops.no_op()) @@ -59,7 +59,7 @@ class EstimatorSpecTrainTest(test.TestCase): loss = constant_op.constant(1.) predictions = {'loss': loss} model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, predictions=predictions, loss=loss, train_op=control_flow_ops.no_op(), @@ -77,7 +77,7 @@ class EstimatorSpecTrainTest(test.TestCase): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=1., train_op=control_flow_ops.no_op()) @@ -85,7 +85,7 @@ class EstimatorSpecTrainTest(test.TestCase): """Tests that no errors are raised when loss is 1D tensor.""" with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant([1.]), train_op=control_flow_ops.no_op()) @@ -93,14 +93,13 @@ class EstimatorSpecTrainTest(test.TestCase): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp(ValueError, 'Missing loss'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, - train_op=control_flow_ops.no_op()) + mode=model_fn.ModeKeys.FIT, train_op=control_flow_ops.no_op()) def testLossNotScalar(self): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant([1., 2.]), train_op=control_flow_ops.no_op()) @@ -112,7 +111,7 @@ class EstimatorSpecTrainTest(test.TestCase): dense_shape=[1]) with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=loss, train_op=control_flow_ops.no_op()) @@ -123,7 +122,7 @@ class EstimatorSpecTrainTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=loss, train_op=control_flow_ops.no_op()) @@ -131,16 +130,16 @@ class EstimatorSpecTrainTest(test.TestCase): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp(ValueError, 'Missing train_op'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, - loss=constant_op.constant(1.)) + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.)) - def testTrainOpNotOperation(self): + def testTrainOpNotOperationAndTensor(self): with ops.Graph().as_default(), self.test_session(): - with self.assertRaisesRegexp(TypeError, 'train_op must be Operation'): + with self.assertRaisesRegexp(TypeError, + 'train_op must be Operation or Tensor'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), - train_op=constant_op.constant(1.)) + train_op='Not an Operation or Tensor') def testTrainOpFromDifferentGraph(self): with ops.Graph().as_default(): @@ -149,7 +148,7 @@ class EstimatorSpecTrainTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'must be from the default graph'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), train_op=train_op) @@ -158,7 +157,7 @@ class EstimatorSpecTrainTest(test.TestCase): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), train_op=control_flow_ops.no_op(), training_chief_hooks=[_InvalidHook()]) @@ -168,7 +167,7 @@ class EstimatorSpecTrainTest(test.TestCase): with self.assertRaisesRegexp( TypeError, 'All hooks must be SessionRunHook instances'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), train_op=control_flow_ops.no_op(), training_hooks=[_InvalidHook()]) @@ -178,7 +177,7 @@ class EstimatorSpecTrainTest(test.TestCase): with self.assertRaisesRegexp( TypeError, r'scaffold must be tf\.train\.Scaffold'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.TRAIN, + mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.), train_op=control_flow_ops.no_op(), scaffold=_InvalidScaffold()) @@ -346,6 +345,16 @@ class EstimatorSpecEvalTest(test.TestCase): loss=loss, eval_metric_ops={'loss': loss}) + def testEvalMetricOpsNoTensorOrOperation(self): + with ops.Graph().as_default(), self.test_session(): + loss = constant_op.constant(1.) + with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'): + model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, + predictions={'loss': loss}, + loss=loss, + eval_metric_ops={'loss': ('NonTensor', loss)}) + def testEvalMetricOpsFromDifferentGraph(self): with ops.Graph().as_default(): eval_metric_ops = { @@ -368,7 +377,7 @@ class EstimatorSpecInferTest(test.TestCase): """Tests that no errors are raised when all required arguments are set.""" with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions={'loss': constant_op.constant(1.)}) def testAllArgumentsSet(self): @@ -377,7 +386,7 @@ class EstimatorSpecInferTest(test.TestCase): loss = constant_op.constant(1.) predictions = {'loss': loss} model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions=predictions, loss=loss, train_op=control_flow_ops.no_op(), @@ -393,23 +402,20 @@ class EstimatorSpecInferTest(test.TestCase): def testPredictionsMissing(self): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp(ValueError, 'Missing predictions'): - model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER) + model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT) def testPredictionsTensor(self): """Tests that no error is raised when predictions is Tensor (not dict).""" with ops.Graph().as_default(), self.test_session(): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, - predictions=constant_op.constant(1.)) + mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.)) def testPredictionsNumber(self): with ops.Graph().as_default(), self.test_session(): with self.assertRaisesRegexp( TypeError, r'predictions\[number\] must be Tensor'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, - predictions={'number': 1.}) + mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.}) def testPredictionsSparseTensor(self): with ops.Graph().as_default(), self.test_session(): @@ -421,8 +427,7 @@ class EstimatorSpecInferTest(test.TestCase): with self.assertRaisesRegexp( TypeError, r'predictions\[sparse\] must be Tensor'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, - predictions=predictions) + mode=model_fn.ModeKeys.PREDICT, predictions=predictions) def testExportOutputsNoDict(self): with ops.Graph().as_default(), self.test_session(): @@ -430,7 +435,7 @@ class EstimatorSpecInferTest(test.TestCase): with self.assertRaisesRegexp( TypeError, 'export_outputs must be dict'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs=(signature_constants.CLASSIFY_METHOD_NAME, predictions)) @@ -441,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase): with self.assertRaisesRegexp( TypeError, 'Values in export_outputs must be 2-tuple'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={'head_name': predictions}) @@ -451,7 +456,7 @@ class EstimatorSpecInferTest(test.TestCase): with self.assertRaisesRegexp( TypeError, 'Values in export_outputs must be 2-tuple'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={'head_name': (predictions,)}) @@ -461,11 +466,9 @@ class EstimatorSpecInferTest(test.TestCase): with self.assertRaisesRegexp( ValueError, 'Invalid signature_method_name in export_outputs'): model_fn.EstimatorSpec( - mode=model_fn.ModeKeys.INFER, + mode=model_fn.ModeKeys.PREDICT, predictions=predictions, - export_outputs={ - 'head_name': ('invalid/method/name', predictions) - }) + export_outputs={'head_name': ('invalid/method/name', predictions)}) if __name__ == '__main__': |