about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 168726a6b3..0f15133d74 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -171,7 +171,6 @@ class TpuEstimator(estimator_lib.Estimator):
Note: TpuEstimator transforms a global batch size in params to a per-shard
batch size when calling the input_fn.
"""
-
def __init__(self,
model_fn=None,
model_dir=None,
@@ -196,6 +195,8 @@ class TpuEstimator(estimator_lib.Estimator):
.format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))
if use_tpu:
+ if not isinstance(config, tpu_config.RunConfig):
+ raise ValueError('`config` must be `tpu_config.RunConfig`')
# Verifies the model_fn signature according to Estimator framework.
estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access
# We cannot store config and params in this constructor as parent
@@ -204,7 +205,6 @@ class TpuEstimator(estimator_lib.Estimator):
model_function = wrapped_model_fn(model_fn)
else:
model_function = model_fn
-
super(TpuEstimator, self).__init__(
model_fn=model_function,
model_dir=model_dir,