aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Yash Katariya <yash.katariya10@gmail.com>2018-08-17 17:06:47 -0400
committerGravatar Yash Katariya <yash.katariya10@gmail.com>2018-08-22 10:15:21 -0400
commit4c2f6aeaaf4aeafccc85a289a5a105d52738b410 (patch)
tree6e51f61f3a6543c03e0c879c58e7374484378e52 /tensorflow/contrib/eager
parentce40173f61c79af05dcd0e0330cdb80bb179585d (diff)
Simplyfing the evaluation step by taking argmax of the softmax of the predictions instead of tf.multinomial
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb2
3 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 315d7a4893..e0f7137184 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -1056,7 +1056,7 @@
"\n",
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n",
" result.append(index_word[predicted_id])\n",
"\n",
" if index_word[predicted_id] == '<end>':\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 40bc098724..b13e5aae9b 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -610,7 +610,7 @@
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index f1e1f99c57..3e02d9fbb0 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -677,7 +677,7 @@
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",