diff options
author | Yunxing Dai <yunxing@google.com> | 2018-09-17 17:50:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 17:58:44 -0700 |
commit | a76646d4b4ad5d56b5e63c139985bbd1eb98dd90 (patch) | |
tree | 475914dcafa5ca137eab7febcef5fc62207ee9ac /tensorflow/contrib/tpu | |
parent | 185aa89912376d4088c22615908696cd30f9951b (diff) |
Add type checking at the beginning of tpu.shard().
Otherwise a message like "TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn." will be thrown, which is confusing.
PiperOrigin-RevId: 213371676
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 815a087a24..593f1d909e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -847,8 +847,12 @@ def shard(computation, if num_shards <= 0: raise ValueError("num_shards must be a positive integer.") + inputs = [] if inputs is None else inputs + if not isinstance(inputs, list): + raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") + # Converts inputs to Tensors. - inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs] + inputs = [ops.convert_to_tensor(x) for x in inputs] if input_shard_axes is None: input_shard_axes = [0] * len(inputs) |