aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_gpu_backend.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_gpu_backend.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
index 62168b6483..dc98d4fda6 100644
--- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
@@ -22,8 +23,16 @@ namespace tensorflow {
bool GpuOpFilter(KernelDef* kdef) {
// TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
// slow code.
- if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
- kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
+ legacy_flags::BackendRegistrationFlags* flags =
+ legacy_flags::GetBackendRegistrationFlags();
+ VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu;
+ if (!flags->tf_enable_prng_ops_gpu &&
+ (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
+ kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) {
+ return false;
+ }
+ // TODO(b/26783907): The GPU backend currently does not implement sort.
+ if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
return false;
}
if (kdef->op() == "Const") {