diff options
author | Derek Murray <mrry@google.com> | 2018-01-09 11:56:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-09 12:01:40 -0800 |
commit | 814e4a7830b506ed26ed22b9b1bc7233d6185467 (patch) | |
tree | 0796671350f5648df2c00740e952174facda88a6 /tensorflow/core/common_runtime/function_test.cc | |
parent | 68cb86ed592d714beabf71402322c9de0e611a69 (diff) |
Add experimental `FunctionLibraryRuntime::InstantiateOptions::overlay_lib`.
This option makes it possible to instantiate functions from a library
that has been loaded separately from the runtime's own library. We
plan to use this as part of the `tf.data` checkpoint restore process,
which might load an iterator whose state includes functions that
aren't present in the original graph. (This is currently achieved by
creating an isolated `FunctionLibraryRuntime` for each function-using
`Dataset`, but that is inefficient and prevents using features of the
main runtime, such as cross-device function calls.)
PiperOrigin-RevId: 181352217
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 58 |
1 files changed, 53 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 2dacacea7b..853484d520 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -205,8 +205,18 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { test::function::Attrs attrs, const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + return InstantiateAndRun(flr, name, attrs, + FunctionLibraryRuntime::InstantiateOptions(), args, + std::move(rets)); + } + + Status InstantiateAndRun( + FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + const std::vector<Tensor>& args, std::vector<Tensor*> rets) { FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(name, attrs, &handle); + Status status = flr->Instantiate(name, attrs, options, &handle); if (!status.ok()) { return status; } @@ -369,6 +379,42 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); } +TEST_F(FunctionLibraryRuntimeTest, XTimesNInOverlayLib) { + Init({}); + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + *proto.add_function() = test::function::XTimesFour(); + *proto.add_function() = test::function::XTimes16(); + std::unique_ptr<FunctionLibraryDefinition> overlay_lib( + new FunctionLibraryDefinition(OpRegistry::Global(), proto)); + + FunctionLibraryRuntime::InstantiateOptions options; + options.overlay_lib = overlay_lib.get(); + + auto x = test::AsTensor<float>({1, 2, 3, 4}); + Tensor y; + + // Ensure that the function is not installed in the base library. + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + {} /* options */, {x}, {&y}), + "Not found: Function XTimesTwo is not defined."); + + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options, + {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, options, + {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); + TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, options, + {x}, {&y})); + test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); + + // Ensure that the use of the overlay has not leaked into the base library. + HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, + {} /* options */, {x}, {&y}), + "Not found: Function XTimesTwo is not defined."); +} + TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); @@ -640,7 +686,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__4") + s.WithOpName("x4/x2/scale/_12__cf__6") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -846,13 +892,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_5__cf__8") + s.WithOpName("scale/_5__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_6__cf__9") + s.WithOpName("Func/_1/sy/_6__cf__11") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( @@ -1090,8 +1136,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::InstantiateOptions instantiate_opts; + instantiate_opts.target = "/device:CPU:1"; FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle)); + TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, instantiate_opts, &handle)); Tensor y; FunctionLibraryRuntime::Options opts; |