diff options
author | 2017-10-08 16:18:24 -0700 | |
---|---|---|
committer | 2017-10-08 16:23:00 -0700 | |
commit | e0924e0577fe42b455be5fb881647fa64ea5b7c3 (patch) | |
tree | a5f2c7d23a3ddd576d0593feef12972fe6a70346 | |
parent | cab4f6f615e259546a1c0719a32d019730b2ee71 (diff) |
[TFXLA] Don't discard status unless it is NotFound.
PiperOrigin-RevId: 171477807
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 19 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 99 |
3 files changed, 90 insertions, 30 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8521d4167a..1cd96fc4e2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -92,7 +92,6 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) } local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), - FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, @@ -142,8 +141,17 @@ Status XlaCompiler::CompileFunction( } const FunctionBody* fbody; - if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) { - TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody)); + // The function may be in either the local_flib_runtime_ or flib_runtime_. + // Look up the function in local first and if it is not found then look up the + // function in flib_runtime_. + auto status = GetFunctionBody(function, local_flib_runtime_, &fbody); + if (!status.ok()) { + if (!errors::IsNotFound(status)) { + return status; + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + GetFunctionBody(function, flib_runtime_, &fbody), + "Local lookup failed with: ", status.error_message()); } TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); @@ -509,7 +517,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); + TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -546,7 +554,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, i < context->retvals().size(); ++i) { const XlaExpression& retval = context->retvals()[i]; if (!retval.has_constant_value()) { - CHECK_LT(computation_output, num_computation_outputs); + TF_RET_CHECK(computation_output < num_computation_outputs) + << "Computation has more outputs than expected"; OutputDescription& output = result->outputs[i]; output.is_constant = false; TF_RETURN_IF_ERROR(XLAShapeToTensorShape( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 35159dbad4..addea74fc2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -287,6 +287,8 @@ class XlaCompiler { FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } private: + friend class XlaCompilerTest; + Options options_; // Status set to non-OK in the constructor if initialization fails. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 531725a623..9af557e23c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" @@ -36,6 +37,37 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaCompilerTest : public ::testing::Test { + protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + } + + XlaCompiler::Options DefaultOptions() { + XlaCompiler::Options options; + options.device_type = &cpu_device_type_; + options.client = client_; + options.flib_def = flib_def_.get(); + return options; + } + + FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) { + return compiler->local_flib_def_.get(); + } + + DeviceType cpu_device_type_; + xla::Client* client_; + std::unique_ptr<FunctionLibraryDefinition> flib_def_; +}; + namespace { // Helper class to test the ability to pass resources through to XLA @@ -125,31 +157,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); -class XlaCompilerTest : public ::testing::Test { - protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - - void SetUp() override { - client_ = xla::ClientLibrary::LocalClientOrDie(); - - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - } - - XlaCompiler::Options DefaultOptions() { - XlaCompiler::Options options; - options.device_type = &cpu_device_type_; - options.client = client_; - options.flib_def = flib_def_.get(); - return options; - } - - DeviceType cpu_device_type_; - xla::Client* client_; - std::unique_ptr<FunctionLibraryDefinition> flib_def_; -}; // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { @@ -489,5 +496,47 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { EXPECT_EQ(1, result.resource_updates.size()); } +// Tests CompileFunction with undefined function fails. +TEST_F(XlaCompilerTest, UndefinedFunctionFails) { + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("Function_NotDefined_"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); +} + +// Tests CompileFunction with a local function lookup failing, fails with +// informative error about both lookups. +TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { + XlaCompiler compiler(DefaultOptions()); + + auto local_flib_def = LocalFlibDef(&compiler); + TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo())); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("XTimesTwo"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + + ASSERT_FALSE(status.ok()); + // Flib lookup failure. + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); + // Local flib lookup failure. + EXPECT_TRUE( + StringPiece(status.error_message()).contains("Attr T is not found")) + << status.error_message(); +} + } // namespace } // namespace tensorflow |