diff options
author | 2017-02-07 10:37:25 -0800 | |
---|---|---|
committer | 2017-02-07 10:44:57 -0800 | |
commit | 9b40599da2d13f9c36861c6525e9c45d699aabfb (patch) | |
tree | 082e076d97da8e7e3dcbf47f1a14fb5e21b7a3df /tensorflow/contrib/factorization/BUILD | |
parent | 8ff1c465c87fc3967c9d480646fac6d6205f856c (diff) |
Add masked_matmul_ops to tensorflow.
Change: 146803360
Diffstat (limited to 'tensorflow/contrib/factorization/BUILD')
-rw-r--r-- | tensorflow/contrib/factorization/BUILD | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 4f613bc5b7..5f09851360 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "factorization_py", @@ -76,6 +77,7 @@ tf_custom_op_library( "ops/factorization_ops.cc", ], deps = [ + "//tensorflow/contrib/factorization/kernels:masked_matmul_ops", "//tensorflow/contrib/factorization/kernels:wals_solver_ops", ], ) @@ -195,6 +197,40 @@ tf_py_test( ], ) +tf_py_test( + name = "masked_matmul_ops_test", + srcs = ["python/kernel_tests/masked_matmul_ops_test.py"], + additional_deps = [ + ":gen_factorization_ops", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "masked_matmul_benchmark", + srcs = ["python/kernel_tests/masked_matmul_benchmark.py"], + additional_deps = [ + ":gen_factorization_ops", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:random_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:variables", + ], + main = "python/kernel_tests/masked_matmul_benchmark.py", +) + # All files filegroup( name = "all_files", |