diff options
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 168726a6b3..0f15133d74 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -171,7 +171,6 @@ class TpuEstimator(estimator_lib.Estimator): Note: TpuEstimator transforms a global batch size in params to a per-shard batch size when calling the input_fn. """ - def __init__(self, model_fn=None, model_dir=None, @@ -196,6 +195,8 @@ class TpuEstimator(estimator_lib.Estimator): .format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards)) if use_tpu: + if not isinstance(config, tpu_config.RunConfig): + raise ValueError('`config` must be `tpu_config.RunConfig`') # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent @@ -204,7 +205,6 @@ class TpuEstimator(estimator_lib.Estimator): model_function = wrapped_model_fn(model_fn) else: model_function = model_fn - super(TpuEstimator, self).__init__( model_fn=model_function, model_dir=model_dir, |