aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-09-17 17:50:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 17:58:44 -0700
commita76646d4b4ad5d56b5e63c139985bbd1eb98dd90 (patch)
tree475914dcafa5ca137eab7febcef5fc62207ee9ac /tensorflow/contrib/tpu
parent185aa89912376d4088c22615908696cd30f9951b (diff)
Add type checking at the beginning of tpu.shard().
Otherwise a message like "TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn." will be thrown, which is confusing. PiperOrigin-RevId: 213371676
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 815a087a24..593f1d909e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -847,8 +847,12 @@ def shard(computation,
if num_shards <= 0:
raise ValueError("num_shards must be a positive integer.")
+ inputs = [] if inputs is None else inputs
+ if not isinstance(inputs, list):
+ raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
+
# Converts inputs to Tensors.
- inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
if input_shard_axes is None:
input_shard_axes = [0] * len(inputs)