aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/examples
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-02-15 19:12:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-15 19:16:17 -0800
commit98cf337e781977fd464c574656699b3181eddf19 (patch)
treedf5e6804542f99281422babe38fc74f8320c5e43 /third_party/examples
parent72bd433b9b6b06ae13893015361079dda992d3c8 (diff)
TFE SPINN example: use tensor instead of numpy array
in inference output. PiperOrigin-RevId: 185939805
Diffstat (limited to 'third_party/examples')
-rw-r--r--third_party/examples/eager/spinn/README.md4
-rw-r--r--third_party/examples/eager/spinn/spinn.py7
2 files changed, 5 insertions, 6 deletions
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md
index 335c0fa3b5..7f477d1920 100644
--- a/third_party/examples/eager/spinn/README.md
+++ b/third_party/examples/eager/spinn/README.md
@@ -75,7 +75,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
should all be separated by spaces. For instance,
```bash
- pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
--inference_premise '( ( The dog ) ( ( is running ) . ) )' \
--inference_hypothesis '( ( The dog ) ( moves . ) )'
```
@@ -93,7 +93,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth
By contrast, the following sentence pair:
```bash
- pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
+ python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
--inference_premise '( ( The dog ) ( ( is running ) . ) )' \
--inference_hypothesis '( ( The dog ) ( rests . ) )'
```
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 38ba48d501..8a1c7db2ea 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -44,7 +44,6 @@ import os
import sys
import time
-import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
@@ -567,7 +566,7 @@ def train_or_infer_spinn(embed,
Returns:
If `config.inference_premise ` and `config.inference_hypothesis` are not
`None`, i.e., inference mode: the logits for the possible labels of the
- SNLI data set, as numpy array of three floats.
+ SNLI data set, as a `Tensor` of three floats.
else:
The trainer object.
Raises:
@@ -626,8 +625,8 @@ def train_or_infer_spinn(embed,
inference_logits = model( # pylint: disable=not-callable
tf.constant(prem), tf.constant(prem_trans),
tf.constant(hypo), tf.constant(hypo_trans), training=False)
- inference_logits = np.array(inference_logits[0][1:])
- max_index = np.argmax(inference_logits)
+ inference_logits = inference_logits[0][1:]
+ max_index = tf.argmax(inference_logits)
print("\nInference logits:")
for i, (label, logit) in enumerate(
zip(data.POSSIBLE_LABELS, inference_logits)):