aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2017-10-08 16:18:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-08 16:23:00 -0700
commite0924e0577fe42b455be5fb881647fa64ea5b7c3 (patch)
treea5f2c7d23a3ddd576d0593feef12972fe6a70346
parentcab4f6f615e259546a1c0719a32d019730b2ee71 (diff)
[TFXLA] Don't discard status unless it is NotFound.
PiperOrigin-RevId: 171477807
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc19
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc99
3 files changed, 90 insertions, 30 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 8521d4167a..1cd96fc4e2 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -92,7 +92,6 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
}
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
-
FunctionDefLibrary{}));
local_pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version,
@@ -142,8 +141,17 @@ Status XlaCompiler::CompileFunction(
}
const FunctionBody* fbody;
- if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) {
- TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody));
+ // The function may be in either the local_flib_runtime_ or flib_runtime_.
+ // Look up the function in local first and if it is not found then look up the
+ // function in flib_runtime_.
+ auto status = GetFunctionBody(function, local_flib_runtime_, &fbody);
+ if (!status.ok()) {
+ if (!errors::IsNotFound(status)) {
+ return status;
+ }
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ GetFunctionBody(function, flib_runtime_, &fbody),
+ "Local lookup failed with: ", status.error_message());
}
TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args));
@@ -509,7 +517,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
result->requires_runtime_context = context->has_context_parameter();
// Tuple arguments and runtime context parameters are incompatible.
- CHECK(!(options.use_tuple_arg && result->requires_runtime_context));
+ TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context));
VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs;
@@ -546,7 +554,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
i < context->retvals().size(); ++i) {
const XlaExpression& retval = context->retvals()[i];
if (!retval.has_constant_value()) {
- CHECK_LT(computation_output, num_computation_outputs);
+ TF_RET_CHECK(computation_output < num_computation_outputs)
+ << "Computation has more outputs than expected";
OutputDescription& output = result->outputs[i];
output.is_constant = false;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 35159dbad4..addea74fc2 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -287,6 +287,8 @@ class XlaCompiler {
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
private:
+ friend class XlaCompilerTest;
+
Options options_;
// Status set to non-OK in the constructor if initialization fails.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 531725a623..9af557e23c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
@@ -36,6 +37,37 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
+
+class XlaCompilerTest : public ::testing::Test {
+ protected:
+ XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
+
+ void SetUp() override {
+ client_ = xla::ClientLibrary::LocalClientOrDie();
+
+ XlaOpRegistry::RegisterCompilationKernels();
+
+ FunctionDefLibrary flib;
+ flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
+ }
+
+ XlaCompiler::Options DefaultOptions() {
+ XlaCompiler::Options options;
+ options.device_type = &cpu_device_type_;
+ options.client = client_;
+ options.flib_def = flib_def_.get();
+ return options;
+ }
+
+ FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
+ return compiler->local_flib_def_.get();
+ }
+
+ DeviceType cpu_device_type_;
+ xla::Client* client_;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+};
+
namespace {
// Helper class to test the ability to pass resources through to XLA
@@ -125,31 +157,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
DummyDuplicateOp);
-class XlaCompilerTest : public ::testing::Test {
- protected:
- XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
- void SetUp() override {
- client_ = xla::ClientLibrary::LocalClientOrDie();
-
- XlaOpRegistry::RegisterCompilationKernels();
-
- FunctionDefLibrary flib;
- flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
- }
-
- XlaCompiler::Options DefaultOptions() {
- XlaCompiler::Options options;
- options.device_type = &cpu_device_type_;
- options.client = client_;
- options.flib_def = flib_def_.get();
- return options;
- }
-
- DeviceType cpu_device_type_;
- xla::Client* client_;
- std::unique_ptr<FunctionLibraryDefinition> flib_def_;
-};
// Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
@@ -489,5 +496,47 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
EXPECT_EQ(1, result.resource_updates.size());
}
+// Tests CompileFunction with undefined function fails.
+TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ XlaCompiler::CompilationResult result;
+ NameAttrList name_attr;
+ name_attr.set_name("Function_NotDefined_");
+ Status status =
+ compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
+ /*args=*/{}, &result);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined."))
+ << status.error_message();
+}
+
+// Tests CompileFunction with a local function lookup failing, fails with
+// informative error about both lookups.
+TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
+ XlaCompiler compiler(DefaultOptions());
+
+ auto local_flib_def = LocalFlibDef(&compiler);
+ TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ XlaCompiler::CompilationResult result;
+ NameAttrList name_attr;
+ name_attr.set_name("XTimesTwo");
+ Status status =
+ compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
+ /*args=*/{}, &result);
+
+ ASSERT_FALSE(status.ok());
+ // Flib lookup failure.
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined."))
+ << status.error_message();
+ // Local flib lookup failure.
+ EXPECT_TRUE(
+ StringPiece(status.error_message()).contains("Attr T is not found"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow