diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_gpu_backend.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_gpu_backend.cc | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 62168b6483..dc98d4fda6 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -22,8 +23,16 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || - kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { + legacy_flags::BackendRegistrationFlags* flags = + legacy_flags::GetBackendRegistrationFlags(); + VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu; + if (!flags->tf_enable_prng_ops_gpu && + (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || + kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) { + return false; + } + // TODO(b/26783907): The GPU backend currently does not implement sort. + if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { return false; } if (kdef->op() == "Const") { |