diff options
author | Shanqing Cai <cais@google.com> | 2018-02-15 19:12:05 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-15 19:16:17 -0800 |
commit | 98cf337e781977fd464c574656699b3181eddf19 (patch) | |
tree | df5e6804542f99281422babe38fc74f8320c5e43 /third_party/examples | |
parent | 72bd433b9b6b06ae13893015361079dda992d3c8 (diff) |
TFE SPINN example: use tensor instead of numpy array
in inference output.
PiperOrigin-RevId: 185939805
Diffstat (limited to 'third_party/examples')
-rw-r--r-- | third_party/examples/eager/spinn/README.md | 4 | ||||
-rw-r--r-- | third_party/examples/eager/spinn/spinn.py | 7 |
2 files changed, 5 insertions, 6 deletions
diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md index 335c0fa3b5..7f477d1920 100644 --- a/third_party/examples/eager/spinn/README.md +++ b/third_party/examples/eager/spinn/README.md @@ -75,7 +75,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth should all be separated by spaces. For instance, ```bash - pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \ --inference_hypothesis '( ( The dog ) ( moves . ) )' ``` @@ -93,7 +93,7 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth By contrast, the following sentence pair: ```bash - pythons spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ --inference_premise '( ( The dog ) ( ( is running ) . ) )' \ --inference_hypothesis '( ( The dog ) ( rests . ) )' ``` diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 38ba48d501..8a1c7db2ea 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -44,7 +44,6 @@ import os import sys import time -import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf @@ -567,7 +566,7 @@ def train_or_infer_spinn(embed, Returns: If `config.inference_premise ` and `config.inference_hypothesis` are not `None`, i.e., inference mode: the logits for the possible labels of the - SNLI data set, as numpy array of three floats. + SNLI data set, as a `Tensor` of three floats. else: The trainer object. Raises: @@ -626,8 +625,8 @@ def train_or_infer_spinn(embed, inference_logits = model( # pylint: disable=not-callable tf.constant(prem), tf.constant(prem_trans), tf.constant(hypo), tf.constant(hypo_trans), training=False) - inference_logits = np.array(inference_logits[0][1:]) - max_index = np.argmax(inference_logits) + inference_logits = inference_logits[0][1:] + max_index = tf.argmax(inference_logits) print("\nInference logits:") for i, (label, logit) in enumerate( zip(data.POSSIBLE_LABELS, inference_logits)): |