diff options
author | 2017-04-19 15:22:04 -0800 | |
---|---|---|
committer | 2017-04-19 16:44:19 -0700 | |
commit | 72164f00a95eaa8468fe00e84fe356c753e2c164 (patch) | |
tree | bf3be6c35e282867095089336599abe351a3a559 /tensorflow/contrib/seq2seq/BUILD | |
parent | 4a61e0bf03542c698ea045acb71e8f6e48db60d7 (diff) |
BeamSearchDecoder:
1. Implemented in contrib/seq2seq.
2. Includes changes to attention_wrapper.py to accommodate the batch_size tiling.
3. Includes changes to decoder.py to accommodate having a finalize step.
Change: 153649516
Diffstat (limited to 'tensorflow/contrib/seq2seq/BUILD')
-rw-r--r-- | tensorflow/contrib/seq2seq/BUILD | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 011c3ba427..f1e39a1373 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -37,11 +37,14 @@ tf_custom_op_py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:platform", "//tensorflow/python:rnn", "//tensorflow/python:rnn_cell", + "//tensorflow/python:script_ops", "//tensorflow/python:tensor_array_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//third_party/py/numpy", ], ) @@ -156,6 +159,27 @@ cuda_py_test( ) cuda_py_test( + name = "beam_search_decoder_test", + size = "small", + srcs = ["python/kernel_tests/beam_search_decoder_test.py"], + additional_deps = [ + ":seq2seq_py", + "//third_party/py/numpy", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +cuda_py_test( name = "attention_wrapper_test", size = "medium", srcs = ["python/kernel_tests/attention_wrapper_test.py"], |