aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/BUILD
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-19 15:22:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-19 16:44:19 -0700
commit72164f00a95eaa8468fe00e84fe356c753e2c164 (patch)
treebf3be6c35e282867095089336599abe351a3a559 /tensorflow/contrib/seq2seq/BUILD
parent4a61e0bf03542c698ea045acb71e8f6e48db60d7 (diff)
BeamSearchDecoder:
1. Implemented in contrib/seq2seq. 2. Includes changes to attention_wrapper.py to accommodate the batch_size tiling. 3. Includes changes to decoder.py to accommodate having a finalize step. Change: 153649516
Diffstat (limited to 'tensorflow/contrib/seq2seq/BUILD')
-rw-r--r--tensorflow/contrib/seq2seq/BUILD24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index 011c3ba427..f1e39a1373 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -37,11 +37,14 @@ tf_custom_op_py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
"//tensorflow/python:rnn",
"//tensorflow/python:rnn_cell",
+ "//tensorflow/python:script_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
+ "//third_party/py/numpy",
],
)
@@ -156,6 +159,27 @@ cuda_py_test(
)
cuda_py_test(
+ name = "beam_search_decoder_test",
+ size = "small",
+ srcs = ["python/kernel_tests/beam_search_decoder_test.py"],
+ additional_deps = [
+ ":seq2seq_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:rnn",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
name = "attention_wrapper_test",
size = "medium",
srcs = ["python/kernel_tests/attention_wrapper_test.py"],