aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn_test.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-02-24 14:51:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-24 15:54:52 -0800
commit3e368f12824dcb2961e1d28647b85ef84b23e5c4 (patch)
tree48be80cfffe25cb7b61a32af85e4fb32b12aede1 /tensorflow/python/estimator/model_fn_test.py
parent54690047e3c7e9673f3bc8bf10d2b66337544dd4 (diff)
Move estimator.fit into core.
Change: 148507023
Diffstat (limited to 'tensorflow/python/estimator/model_fn_test.py')
-rw-r--r--tensorflow/python/estimator/model_fn_test.py79
1 files changed, 41 insertions, 38 deletions
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index 354679a1f7..b80b965196 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -35,11 +35,11 @@ class _FakeHook(session_run_hook.SessionRunHook):
class _InvalidHook(object):
- """Invalid hook (not a subclass of `SessionRunHook`."""
+ """Invalid hook (not a subclass of `SessionRunHook`)."""
class _InvalidScaffold(object):
- """Invalid scaffold (not a subclass of `Scaffold`."""
+ """Invalid scaffold (not a subclass of `Scaffold`)."""
class EstimatorSpecTrainTest(test.TestCase):
@@ -49,7 +49,7 @@ class EstimatorSpecTrainTest(test.TestCase):
"""Tests that no errors are raised when all required arguments are set."""
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
train_op=control_flow_ops.no_op())
@@ -59,7 +59,7 @@ class EstimatorSpecTrainTest(test.TestCase):
loss = constant_op.constant(1.)
predictions = {'loss': loss}
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
@@ -77,7 +77,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=1.,
train_op=control_flow_ops.no_op())
@@ -85,7 +85,7 @@ class EstimatorSpecTrainTest(test.TestCase):
"""Tests that no errors are raised when loss is 1D tensor."""
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant([1.]),
train_op=control_flow_ops.no_op())
@@ -93,14 +93,13 @@ class EstimatorSpecTrainTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
- train_op=control_flow_ops.no_op())
+ mode=model_fn.ModeKeys.FIT, train_op=control_flow_ops.no_op())
def testLossNotScalar(self):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant([1., 2.]),
train_op=control_flow_ops.no_op())
@@ -112,7 +111,7 @@ class EstimatorSpecTrainTest(test.TestCase):
dense_shape=[1])
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=loss,
train_op=control_flow_ops.no_op())
@@ -123,7 +122,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=loss,
train_op=control_flow_ops.no_op())
@@ -131,16 +130,16 @@ class EstimatorSpecTrainTest(test.TestCase):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(ValueError, 'Missing train_op'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
- loss=constant_op.constant(1.))
+ mode=model_fn.ModeKeys.FIT, loss=constant_op.constant(1.))
- def testTrainOpNotOperation(self):
+ def testTrainOpNotOperationAndTensor(self):
with ops.Graph().as_default(), self.test_session():
- with self.assertRaisesRegexp(TypeError, 'train_op must be Operation'):
+ with self.assertRaisesRegexp(TypeError,
+ 'train_op must be Operation or Tensor'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
- train_op=constant_op.constant(1.))
+ train_op='Not an Operation or Tensor')
def testTrainOpFromDifferentGraph(self):
with ops.Graph().as_default():
@@ -149,7 +148,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
train_op=train_op)
@@ -158,7 +157,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
train_op=control_flow_ops.no_op(),
training_chief_hooks=[_InvalidHook()])
@@ -168,7 +167,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
train_op=control_flow_ops.no_op(),
training_hooks=[_InvalidHook()])
@@ -178,7 +177,7 @@ class EstimatorSpecTrainTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, r'scaffold must be tf\.train\.Scaffold'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.TRAIN,
+ mode=model_fn.ModeKeys.FIT,
loss=constant_op.constant(1.),
train_op=control_flow_ops.no_op(),
scaffold=_InvalidScaffold())
@@ -346,6 +345,16 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss,
eval_metric_ops={'loss': loss})
+ def testEvalMetricOpsNoTensorOrOperation(self):
+ with ops.Graph().as_default(), self.test_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops={'loss': ('NonTensor', loss)})
+
def testEvalMetricOpsFromDifferentGraph(self):
with ops.Graph().as_default():
eval_metric_ops = {
@@ -368,7 +377,7 @@ class EstimatorSpecInferTest(test.TestCase):
"""Tests that no errors are raised when all required arguments are set."""
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions={'loss': constant_op.constant(1.)})
def testAllArgumentsSet(self):
@@ -377,7 +386,7 @@ class EstimatorSpecInferTest(test.TestCase):
loss = constant_op.constant(1.)
predictions = {'loss': loss}
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
@@ -393,23 +402,20 @@ class EstimatorSpecInferTest(test.TestCase):
def testPredictionsMissing(self):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(ValueError, 'Missing predictions'):
- model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER)
+ model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT)
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
with ops.Graph().as_default(), self.test_session():
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
- predictions=constant_op.constant(1.))
+ mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.))
def testPredictionsNumber(self):
with ops.Graph().as_default(), self.test_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
- predictions={'number': 1.})
+ mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.})
def testPredictionsSparseTensor(self):
with ops.Graph().as_default(), self.test_session():
@@ -421,8 +427,7 @@ class EstimatorSpecInferTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, r'predictions\[sparse\] must be Tensor'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
- predictions=predictions)
+ mode=model_fn.ModeKeys.PREDICT, predictions=predictions)
def testExportOutputsNoDict(self):
with ops.Graph().as_default(), self.test_session():
@@ -430,7 +435,7 @@ class EstimatorSpecInferTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, 'export_outputs must be dict'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs=(signature_constants.CLASSIFY_METHOD_NAME,
predictions))
@@ -441,7 +446,7 @@ class EstimatorSpecInferTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, 'Values in export_outputs must be 2-tuple'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={'head_name': predictions})
@@ -451,7 +456,7 @@ class EstimatorSpecInferTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError, 'Values in export_outputs must be 2-tuple'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={'head_name': (predictions,)})
@@ -461,11 +466,9 @@ class EstimatorSpecInferTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'Invalid signature_method_name in export_outputs'):
model_fn.EstimatorSpec(
- mode=model_fn.ModeKeys.INFER,
+ mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
- export_outputs={
- 'head_name': ('invalid/method/name', predictions)
- })
+ export_outputs={'head_name': ('invalid/method/name', predictions)})
if __name__ == '__main__':