aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-01-09 11:56:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-09 12:01:40 -0800
commit814e4a7830b506ed26ed22b9b1bc7233d6185467 (patch)
tree0796671350f5648df2c00740e952174facda88a6 /tensorflow/core/common_runtime/function_test.cc
parent68cb86ed592d714beabf71402322c9de0e611a69 (diff)
Add experimental `FunctionLibraryRuntime::InstantiateOptions::overlay_lib`.
This option makes it possible to instantiate functions from a library that has been loaded separately from the runtime's own library. We plan to use this as part of the `tf.data` checkpoint restore process, which might load an iterator whose state includes functions that aren't present in the original graph. (This is currently achieved by creating an isolated `FunctionLibraryRuntime` for each function-using `Dataset`, but that is inefficient and prevents using features of the main runtime, such as cross-device function calls.) PiperOrigin-RevId: 181352217
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc58
1 files changed, 53 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 2dacacea7b..853484d520 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -205,8 +205,18 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
test::function::Attrs attrs,
const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
+ return InstantiateAndRun(flr, name, attrs,
+ FunctionLibraryRuntime::InstantiateOptions(), args,
+ std::move(rets));
+ }
+
+ Status InstantiateAndRun(
+ FunctionLibraryRuntime* flr, const string& name,
+ test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
- Status status = flr->Instantiate(name, attrs, &handle);
+ Status status = flr->Instantiate(name, attrs, options, &handle);
if (!status.ok()) {
return status;
}
@@ -369,6 +379,42 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));
}
+TEST_F(FunctionLibraryRuntimeTest, XTimesNInOverlayLib) {
+ Init({});
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ *proto.add_function() = test::function::XTimesFour();
+ *proto.add_function() = test::function::XTimes16();
+ std::unique_ptr<FunctionLibraryDefinition> overlay_lib(
+ new FunctionLibraryDefinition(OpRegistry::Global(), proto));
+
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.overlay_lib = overlay_lib.get();
+
+ auto x = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y;
+
+ // Ensure that the function is not installed in the base library.
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ {} /* options */, {x}, {&y}),
+ "Not found: Function XTimesTwo is not defined.");
+
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));
+
+ // Ensure that the use of the overlay has not leaked into the base library.
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ {} /* options */, {x}, {&y}),
+ "Not found: Function XTimesTwo is not defined.");
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -640,7 +686,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__4")
+ s.WithOpName("x4/x2/scale/_12__cf__6")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -846,13 +892,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_5__cf__8")
+ s.WithOpName("scale/_5__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_6__cf__9")
+ s.WithOpName("Func/_1/sy/_6__cf__11")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
@@ -1090,8 +1136,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
+ FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+ instantiate_opts.target = "/device:CPU:1";
FunctionLibraryRuntime::Handle handle;
- TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle));
+ TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, instantiate_opts, &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;