path: root/tensorflow/core/common_runtime/function_test.cc
author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-03 14:03:00 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-03 15:38:21 -0700
commit     965d620104d375c5fd2b18881f353eb41d9a63a2 (patch)
tree       a801f0e211bf6ad5eb81536eea9343edf0544dfa /tensorflow/core/common_runtime/function_test.cc
parent     7828637e07b0081a37dfdc66ff912dd1d6ff3228 (diff)
Internal change.
Change: 155009390
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc  1036
1 file changed, 630 insertions, 406 deletions
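The rewrite below follows two patterns throughout: raw owning pointers (Device*, Executor*, FunctionLibraryRuntime*, FunctionLibraryDefinition*) become std::unique_ptr members, and golden strings matched against DebugString output are replaced by expected graphs built with the C++ Scope API and compared structurally with TF_EXPECT_GRAPH_EQ. A minimal sketch of the new comparison idiom, assembled only from calls that appear in the diff (the op names here are illustrative):

  // Build the expected graph with the C++ op wrappers, then compare the
  // serialized GraphDefs structurally instead of as printed strings.
  Scope s = Scope::NewRootScope();
  auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);       // function input
  auto two = ops::Const(s.WithOpName("two"), 2.0f);         // constant operand
  auto y = ops::Mul(s.WithOpName("y"), x, two);             // expected op node
  auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);  // function output
  GraphDef expected;
  TF_ASSERT_OK(s.ToGraphDef(&expected));
  GraphDef actual;
  g->ToGraphDef(&actual);                // g is the Graph produced by the test
  TF_EXPECT_GRAPH_EQ(expected, actual);  // structural GraphDef comparison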
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 8f70ab8783..af1ff6aec0 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -17,6 +17,10 @@ limitations under the License.
#include <atomic>
+#include "tensorflow/cc/ops/array_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
@@ -28,10 +32,12 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -58,13 +64,8 @@ class FunctionTest : public ::testing::Test {
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
- ~FunctionTest() override {
- delete exec_;
- delete device_;
- }
-
void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) {
- delete exec_;
+ exec_ = nullptr;
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
@@ -79,15 +80,18 @@ class FunctionTest : public ::testing::Test {
const int version = g->versions().producer();
LocalExecutorParams params;
- params.device = device_;
+ params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
- return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+ kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
};
- TF_CHECK_OK(NewLocalExecutor(params, g, &exec_));
+ Executor* exec;
+ TF_CHECK_OK(NewLocalExecutor(params, g, &exec));
+ exec_.reset(exec);
}
void Run(const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
@@ -105,8 +109,8 @@ class FunctionTest : public ::testing::Test {
}
}
- Device* device_ = nullptr;
- Executor* exec_ = nullptr;
+ std::unique_ptr<Device> device_;
+ std::unique_ptr<Executor> exec_;
DataTypeVector arg_types_;
DataTypeVector ret_types_;
};
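A condensed sketch of the ownership hand-off adopted above, assuming only the NewLocalExecutor out-parameter contract shown in the hunk: the factory still yields a raw pointer, which the fixture immediately adopts into its std::unique_ptr member.

  std::unique_ptr<Executor> exec_;                   // owning member
  exec_ = nullptr;                                   // replaces 'delete exec_'
  Executor* exec;
  TF_CHECK_OK(NewLocalExecutor(params, g, &exec));   // factory fills raw pointer
  exec_.reset(exec);                                 // unique_ptr takes ownership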
@@ -136,21 +140,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
- ~FunctionLibraryRuntimeTest() override {
- delete lib_;
- delete lib_def_;
- delete device_;
- }
-
void Init(const std::vector<FunctionDef>& flib) {
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
- delete lib_def_;
- lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto);
- delete lib_;
+ lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
- lib_ = NewFunctionLibraryRuntime(nullptr, Env::Default(), device_,
- TF_GRAPH_DEF_VERSION, lib_def_, opts);
+ lib_.reset(NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(),
+ TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts));
+ fdef_lib_ = lib_def_->ToProto();
}
Status Run(const string& name, InstantiateAttrValueSlice attrs,
@@ -190,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK();
}
- Graph* GetFuncBody(const string& name, InstantiateAttrValueSlice attrs) {
+ std::unique_ptr<Graph> GetFuncBody(const string& name,
+ InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -199,12 +198,13 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
const FunctionBody* fbody = lib_->GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
- Graph* ret = new Graph(lib_def_);
- CopyGraph(*fbody->graph, ret);
+ std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
+ CopyGraph(*fbody->graph, ret.get());
return ret;
}
- Graph* GetGradBody(const string& func, InstantiateAttrValueSlice attrs) {
+ std::unique_ptr<Graph> GetGradBody(const string& func,
+ InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(func, attrs, &handle);
if (!status.ok()) {
@@ -213,17 +213,17 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
const FunctionBody* fbody = lib_->GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
- FunctionBody* gbody = SymbolicGradient(*fbody);
+ std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody));
CHECK_NOTNULL(gbody);
- Graph* ret = new Graph(lib_def_);
- CopyGraph(*gbody->graph, ret);
- delete gbody;
+ std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
+ CopyGraph(*gbody->graph, ret.get());
return ret;
}
- Device* device_ = nullptr;
- FunctionLibraryDefinition* lib_def_ = nullptr;
- FunctionLibraryRuntime* lib_ = nullptr;
+ std::unique_ptr<Device> device_;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<FunctionLibraryRuntime> lib_;
+ FunctionDefLibrary fdef_lib_;
};
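Typical flow through this fixture, taken verbatim from the tests that follow: register the FunctionDefs, then materialize an instantiated function body as a Graph:

  Init({test::function::XTimesTwo(), test::function::XTimesFour(),
        test::function::XTimes16()});
  std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
  ASSERT_TRUE(g != nullptr);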
TEST_F(FunctionLibraryRuntimeTest, IsStateful) {
@@ -254,113 +254,174 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));
}
+// Adds a function call to 'scope'.
+// TODO(phawkins): replace with C++ API for calling functions, when that exists.
+Output Call(Scope* scope, const string& op_name, const string& fn_name,
+ gtl::ArraySlice<Input> inputs) {
+ NodeDef def;
+ NodeDefBuilder builder(op_name, fn_name, scope->graph()->op_registry());
+ for (const Input& input : inputs) {
+ builder.Input(input.node()->name(), input.index(),
+ input.node()->output_type(input.index()));
+ }
+ TF_CHECK_OK(builder.Finalize(&def));
+ Status status;
+ Node* n = scope->graph()->AddNode(def, &status);
+ TF_CHECK_OK(status);
+ for (int i = 0; i < inputs.size(); ++i) {
+ scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
+ }
+ return Output(n);
+}
+
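A short usage sketch for the Call helper above (node and function names are illustrative; the tests below use the same pattern):

  Scope s = Scope::NewRootScope();
  TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));  // make fns resolvable
  auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
  auto y = Call(&s, "y", "XTimesTwo", {x});                // node "y" calls the fn
  auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);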
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
- Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
+ std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
ASSERT_TRUE(g != nullptr);
- const char* e0 = R"P(
-(n2:float) -> (n4:float) {
- n3 = XTimesFour[T=float](n2)
- n4 = XTimesFour[T=float](n3)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g));
-
- ExpandInlineFunctions(lib_, g);
- const char* e1 = R"P(
-(n2:float) -> (n17:float) {
- n10 = Identity[T=float](n2)
- n7 = XTimesTwo[T=float](n10)
- n8 = XTimesTwo[T=float](n7)
- n11 = Identity[T=float](n8)
- n16 = Identity[T=float](n11)
- n13 = XTimesTwo[T=float](n16)
- n14 = XTimesTwo[T=float](n13)
- n17 = Identity[T=float](n14)
-}
-)P";
- EXPECT_EQ(e1, DebugString(g));
-
- ExpandInlineFunctions(lib_, g);
- const char* e2 = R"P(
-(n2:float) -> (n17:float) {
- n18 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n25 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n32 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n39 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n19 = Cast[DstT=float, SrcT=int64](n18)
- n26 = Cast[DstT=float, SrcT=int64](n25)
- n33 = Cast[DstT=float, SrcT=int64](n32)
- n40 = Cast[DstT=float, SrcT=int64](n39)
- n10 = Identity[T=float](n2)
- n23 = Identity[T=float](n10)
- n21 = Mul[T=float](n23, n19)
- n24 = Identity[T=float](n21)
- n30 = Identity[T=float](n24)
- n28 = Mul[T=float](n30, n26)
- n31 = Identity[T=float](n28)
- n11 = Identity[T=float](n31)
- n16 = Identity[T=float](n11)
- n37 = Identity[T=float](n16)
- n35 = Mul[T=float](n37, n33)
- n38 = Identity[T=float](n35)
- n44 = Identity[T=float](n38)
- n42 = Mul[T=float](n44, n40)
- n45 = Identity[T=float](n42)
- n17 = Identity[T=float](n45)
-}
-)P";
- EXPECT_EQ(e2, DebugString(g));
+
+ {
+ Scope s = Scope::NewRootScope();
+ TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+ auto arg = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto a = Call(&s, "x4", "XTimesFour", {arg});
+ auto b = Call(&s, "y", "XTimesFour", {a});
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), b, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
+
+ ExpandInlineFunctions(lib_.get(), g.get());
+ {
+ Scope s = Scope::NewRootScope();
+ TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto func0 = ops::Identity(s.WithOpName("Func/_0"), x);
+ auto x4_x2 = Call(&s, "x4/x2", "XTimesTwo", {func0});
+ auto x4_y = Call(&s, "x4/y", "XTimesTwo", {x4_x2});
+ auto func1 = ops::Identity(s.WithOpName("Func/_1"), x4_y);
+ auto func2 = ops::Identity(s.WithOpName("Func/_2"), func1);
+ auto y_x2 = Call(&s, "y/x2", "XTimesTwo", {func2});
+ auto y_y = Call(&s, "y/y", "XTimesTwo", {y_x2});
+ auto func3 = ops::Identity(s.WithOpName("Func/_3"), y_y);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
+
+ ExpandInlineFunctions(lib_.get(), g.get());
+ GraphDef e2;
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL);
+ auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL);
+ auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL);
+ auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL);
+ auto x4_x2_scale =
+ ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
+ auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
+ auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
+ auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
+ auto func0 = ops::Identity(s.WithOpName("Func/_0"), x);
+ auto func4 = ops::Identity(s.WithOpName("Func/_4"), func0);
+ auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), func4, x4_x2_scale);
+ auto func5 = ops::Identity(s.WithOpName("Func/_5"), x4_x2_y);
+ auto func6 = ops::Identity(s.WithOpName("Func/_6"), func5);
+ auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), func6, x4_y_scale);
+ auto func7 = ops::Identity(s.WithOpName("Func/_7"), x4_y_y);
+ auto func1 = ops::Identity(s.WithOpName("Func/_1"), func7);
+ auto func2 = ops::Identity(s.WithOpName("Func/_2"), func1);
+ auto func8 = ops::Identity(s.WithOpName("Func/_8"), func2);
+ auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), func8, y_x2_scale);
+ auto func9 = ops::Identity(s.WithOpName("Func/_9"), y_x2_y);
+ auto func10 = ops::Identity(s.WithOpName("Func/_10"), func9);
+ auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), func10, y_y_scale);
+ auto func11 = ops::Identity(s.WithOpName("Func/_11"), y_y_y);
+ auto func3 = ops::Identity(s.WithOpName("Func/_3"), func11);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
+ TF_ASSERT_OK(s.ToGraphDef(&e2));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(e2, actual);
+ }
// No further inlining.
- ExpandInlineFunctions(lib_, g);
- EXPECT_EQ(e2, DebugString(g));
+ ExpandInlineFunctions(lib_.get(), g.get());
+ {
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(e2, actual);
+ }
// Get rid of redundant Identity nodes.
- RemoveIdentityNodes(g);
- const char* e3 = R"P(
-(n2:float) -> (n42:float) {
- n18 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n25 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n32 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n39 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n19 = Cast[DstT=float, SrcT=int64](n18)
- n26 = Cast[DstT=float, SrcT=int64](n25)
- n33 = Cast[DstT=float, SrcT=int64](n32)
- n40 = Cast[DstT=float, SrcT=int64](n39)
- n21 = Mul[T=float](n2, n19)
- n28 = Mul[T=float](n21, n26)
- n35 = Mul[T=float](n28, n33)
- n42 = Mul[T=float](n35, n40)
-}
-)P";
- EXPECT_EQ(e3, DebugString(g));
- delete g;
+ RemoveIdentityNodes(g.get());
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL);
+ auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL);
+ auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL);
+ auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL);
+ auto x4_x2_scale =
+ ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
+ auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
+ auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
+ auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
+ auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
+ auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_y_scale);
+ auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, y_x2_scale);
+ auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, y_y_scale);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
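Condensed, the driver calls this test exercises are (each taken directly from the test body above):

  std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
  ExpandInlineFunctions(lib_.get(), g.get());  // inlines the two XTimesFour calls
  ExpandInlineFunctions(lib_.get(), g.get());  // inlines the XTimesTwo bodies
  ExpandInlineFunctions(lib_.get(), g.get());  // fixed point: graph unchanged
  RemoveIdentityNodes(g.get());                // drops the Func/_N Identity bridges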
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
- std::unique_ptr<Graph> g(GetFuncBody("XTimes16", {{"T", DT_FLOAT}}));
+ std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
ASSERT_TRUE(g != nullptr);
- ExpandInlineFunctions(lib_, g.get());
- OptimizeGraph(lib_, &g);
- const char* e0 = R"P(
-(n2:float) -> (n7:float) {
- n8 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
- n4 = Mul[T=float](n2, n8)
- n5 = Mul[T=float](n4, n8)
- n6 = Mul[T=float](n5, n8)
- n7 = Mul[T=float](n6, n8)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g.get()));
+ ExpandInlineFunctions(lib_.get(), g.get());
+ OptimizeGraph(lib_.get(), &g);
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto x4_x2_scale = ops::Const<float>(
+ s.WithOpName("x4/x2/scale/_12__cf__2")
+ .WithDevice("/job:localhost/replica:0/task:0/cpu:0"),
+ 2.0f);
+ auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
+ auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_x2_scale);
+ auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, x4_x2_scale);
+ auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, x4_x2_scale);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
- // Name
+ // Name
"ManySwapsNodeDef",
// Input
{"x: float", "y: float"},
@@ -379,9 +440,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
// Return
{{"o", "g:output"}});
Init({test::function::Swap(), func});
- std::unique_ptr<Graph> g(GetFuncBody("ManySwapsNodeDef", {}));
+ std::unique_ptr<Graph> g = GetFuncBody("ManySwapsNodeDef", {});
ASSERT_TRUE(g != nullptr);
- OptimizeGraph(lib_, &g);
+ OptimizeGraph(lib_.get(), &g);
const char* e0 = R"P(
(n3:float, n2:float) -> (n3:float) {
}
@@ -412,24 +473,35 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
{{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
{{"o", "o:z:0"}});
Init({test::function::Swap(), func});
- std::unique_ptr<Graph> g(GetFuncBody("ManySwapsFirst", {}));
+ std::unique_ptr<Graph> g = GetFuncBody("ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr);
- OptimizeGraph(lib_, &g);
+ OptimizeGraph(lib_.get(), &g);
- // NOTE: We can remove n8, n9, n10, n11 with a control edge n8->n5.
+ // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5.
// But we don't have a pass doing that.
- const char* e0 = R"P(
-(n3:float, n2:float) -> (n6:float) {
- n4 = Mul[T=float](n3, n3)
- n8 = NoOp() @ n4
- n9 = Identity[T=float](n3) @ n8
- n10 = Identity[T=float](n2) @ n8
- n11 = NoOp() @ n9, n10
- n5 = Mul[T=float](n2, n2) @ n11
- n6 = Add[T=float](n4, n5)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g.get()));
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto x2 = ops::Mul(s.WithOpName("x2"), x, x);
+ auto func0 = ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies(x2));
+ auto func1 = ops::Identity(
+ s.WithOpName("Func/_1").WithControlDependencies({func0}), x);
+ auto func2 = ops::Identity(
+ s.WithOpName("Func/_2").WithControlDependencies({func0}), y);
+ auto func9 = ops::NoOp(s.WithOpName("Func/_9").WithControlDependencies(
+ {func1.output.op(), func2.output.op()}));
+ auto y2 =
+ ops::Mul(s.WithOpName("y2").WithControlDependencies({func9}), y, y);
+ auto o = ops::Add(s.WithOpName("o"), x2, y2);
+ auto ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
@@ -476,84 +548,136 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
- auto f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}});
- const char* e0 = R"P(
-(n4:float) -> (n5:float) {
- n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n3 = Cast[DstT=float, SrcT=int64](n2)
- n5 = Mul[T=float](n4, n3)
-}
-)P";
- EXPECT_EQ(e0, DebugString(f));
- delete f;
- std::unique_ptr<Graph> g(GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}));
- const char* e1 = R"P(
-(n4:float, n6:float) -> (n7:float) {
- n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
- n3 = Cast[DstT=float, SrcT=int64](n2)
- n5 = Mul[T=float](n4, n3)
- n7 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Mul[T=float]](n4, n3, n6)
-}
-)P";
- EXPECT_EQ(e1, DebugString(g.get()));
-
- OptimizeGraph(lib_, &g);
- const char* e2 = R"P(
-(n2:float, n3:float) -> (n9:float) {
- n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
- n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
- n6 = Shape[T=float, out_type=int32](n2)
- n5 = Mul[T=float](n3, n10)
- n7 = BroadcastGradientArgs[T=int32](n6, n11)
- n8 = Sum[T=float, Tidx=int32, keep_dims=false](n5, n7)
- n9 = Reshape[T=float, Tshape=int32](n8, n6)
-}
-)P";
- EXPECT_EQ(e2, DebugString(g.get()));
+ std::unique_ptr<Graph> f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}});
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto two = ops::Const(s.WithOpName("two"), 2LL);
+ auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
+ auto y = ops::Mul(s.WithOpName("y"), x, scale);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ f->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
+
+ std::unique_ptr<Graph> g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}});
+
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
+ auto two = ops::Const(s.WithOpName("two"), 2LL);
+ auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
+ auto y = ops::Mul(s.WithOpName("y"), x, scale);
+ NameAttrList fn;
+ fn.set_name("Mul");
+ (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
+ auto func1 = ops::SymbolicGradient(
+ s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
+ {DT_FLOAT, DT_FLOAT}, fn);
+ auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1[0], 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
+
+ OptimizeGraph(lib_.get(), &g);
+
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
+ auto scale =
+ ops::Const(s.WithOpName("scale/_5__cf__6")
+ .WithDevice("/job:localhost/replica:0/task:0/cpu:0"),
+ 2.0f);
+ auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
+ auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
+ auto const0 =
+ ops::Const(s.WithOpName("Func/_1/sy/_6__cf__7")
+ .WithDevice("/job:localhost/replica:0/task:0/cpu:0"),
+ 0, {0});
+ auto func1_rx = ops::internal::BroadcastGradientArgs(
+ s.WithOpName("Func/_1/rx"), func1_sx, const0);
+ auto func1_sum_gx =
+ ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
+ auto func1_dx =
+ ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
+ auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1_dx, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
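The NameAttrList construction recurring in these gradient expectations, in minimal form (names copied from the test above; a sketch, not standalone code):

  NameAttrList fn;
  fn.set_name("Mul");                            // function to differentiate
  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);  // its instantiation attrs
  auto grad = ops::SymbolicGradient(
      s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
      {DT_FLOAT, DT_FLOAT}, fn);                 // Tout types of the grad node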
TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
Init({});
auto T = DT_FLOAT;
- auto g = GetFuncBody("SymbolicGradient",
- {{"f", FDH::FunctionRef("Add", {{"T", T}})}});
- const char* e0 = R"P(
-(n7:float, n5:float, n2:float) -> (n14:float, n11:float) {
- n3 = Identity[T=float](n2)
- n4 = Identity[T=float](n2)
- n6 = Shape[T=float, out_type=int32](n5)
- n8 = Shape[T=float, out_type=int32](n7)
- n9 = BroadcastGradientArgs[T=int32](n8, n6)
- n10 = Sum[T=float, Tidx=int32, keep_dims=false](n3, n9:1)
- n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
- n11 = Reshape[T=float, Tshape=int32](n10, n6)
- n14 = Reshape[T=float, Tshape=int32](n13, n8)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g));
- delete g;
+ std::unique_ptr<Graph> g = GetFuncBody(
+ "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}});
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
+ auto gx = ops::Identity(s.WithOpName("gx"), dz);
+ auto gy = ops::Identity(s.WithOpName("gy"), dz);
+ auto sx = ops::Shape(s.WithOpName("sx"), x);
+ auto sy = ops::Shape(s.WithOpName("sy"), y);
+ auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
+ auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
+ auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
+ auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
+ auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
+ auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
+ auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
Init({});
auto T = DT_FLOAT;
- auto g = GetFuncBody("SymbolicGradient",
- {{"f", FDH::FunctionRef("Mul", {{"T", T}})}});
- const char* e0 = R"P(
-(n6:float, n3:float, n2:float) -> (n14:float, n11:float) {
- n4 = Mul[T=float](n2, n3)
- n5 = Shape[T=float, out_type=int32](n3)
- n7 = Mul[T=float](n6, n2)
- n8 = Shape[T=float, out_type=int32](n6)
- n9 = BroadcastGradientArgs[T=int32](n8, n5)
- n10 = Sum[T=float, Tidx=int32, keep_dims=false](n7, n9:1)
- n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
- n11 = Reshape[T=float, Tshape=int32](n10, n5)
- n14 = Reshape[T=float, Tshape=int32](n13, n8)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g));
- delete g;
+ std::unique_ptr<Graph> g = GetFuncBody(
+ "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}});
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
+ auto gx = ops::Mul(s.WithOpName("gx"), dz, y);
+ auto sx = ops::Shape(s.WithOpName("sx"), x);
+ auto gy = ops::Mul(s.WithOpName("gy"), x, dz);
+ auto sy = ops::Shape(s.WithOpName("sy"), y);
+ auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
+ auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
+ auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
+ auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
+ auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
+ auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
+ auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
@@ -570,108 +694,170 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
});
// TestGrad = Test'(x, y)
- auto grad =
- FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"},
- {}, {FDH::Const<float>("dz", 1),
- {{"grad0", "grad1"},
- "SymbolicGradient",
- {"x", "y", "dz"},
- {
- {"f", FDH::FunctionRef("Test")},
- {"Tin", DataTypeSlice{T, T, T}},
- {"Tout", DataTypeSlice{T, T}},
- }},
- {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
- {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
+ auto grad = FDH::Define("TestGrad", {"x:float", "y:float"},
+ {"dx:float", "dy:float"}, {},
+ {FDH::Const<float>("dz", 1),
+ {{"grad0", "grad1"},
+ "SymbolicGradient",
+ {"x", "y", "dz"},
+ {
+ {"f", FDH::FunctionRef("Test")},
+ {"Tin", DataTypeSlice{T, T, T}},
+ {"Tout", DataTypeSlice{T, T}},
+ }},
+ {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
+ {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
Init({test, grad});
- std::unique_ptr<Graph> g(GetFuncBody("TestGrad", {}));
+ std::unique_ptr<Graph> g = GetFuncBody("TestGrad", {});
ASSERT_TRUE(g != nullptr);
- const char* e0 = R"P(
-(n4:float, n3:float) -> (n8:float, n6:float) {
- n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
- n5 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Test](n4, n3, n2)
- n6 = Identity[T=float](n5:1)
- n8 = Identity[T=float](n5)
-}
-)P";
- EXPECT_EQ(e0, DebugString(g.get()));
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
+ NameAttrList fn;
+ fn.set_name("Test");
+ auto grad0 = ops::SymbolicGradient(s.WithOpName("grad0"),
+ std::initializer_list<Input>{x, y, dz},
+ {DT_FLOAT, DT_FLOAT}, fn);
+ auto dx = ops::Identity(s.WithOpName("dx"), grad0[0]);
+ auto dy = ops::Identity(s.WithOpName("dy"), grad0[1]);
+ auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
+ auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
- ExpandInlineFunctions(lib_, g.get());
- const char* e1 = R"P(
-(n4:float, n3:float) -> (n8:float, n6:float) {
- n10 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
- n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
- n26 = Identity[T=float](n2)
- n25 = Identity[T=float](n3)
- n24 = Identity[T=float](n4)
- n14 = Add[T=float](n24, n25)
- n15 = Rank[T=float](n14)
- n16 = Range[Tidx=int32](n11, n15, n10)
- n20 = ZerosLike[T=int32](n15)
- n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
- n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
- n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
- n27 = Identity[T=float](n21)
- n28 = Identity[T=float](n21:1)
- n8 = Identity[T=float](n27)
- n6 = Identity[T=float](n28)
-}
-)P";
- EXPECT_EQ(e1, DebugString(g.get()));
-
- OptimizeGraph(lib_, &g);
- const char* e2 = R"P(
-(n4:float, n3:float) -> (n25:float, n23:float) {
- n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
- n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
- n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- n19 = Shape[T=float, out_type=int32](n3)
- n9 = Add[T=float](n4, n3)
- n20 = Shape[T=float, out_type=int32](n4)
- n10 = Rank[T=float](n9)
- n14 = Shape[T=float, out_type=int32](n9)
- n21 = BroadcastGradientArgs[T=int32](n20, n19)
- n11 = Range[Tidx=int32](n8, n10, n7)
- n12 = Shape[T=int32, out_type=int32](n11)
- n13 = Fill[T=int32](n12, n7)
- n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13)
- n16 = Reshape[T=float, Tshape=int32](n2, n15)
- n17 = Div[T=int32](n14, n15)
- n18 = Tile[T=float, Tmultiples=int32](n16, n17)
- n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
- n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
- n23 = Reshape[T=float, Tshape=int32](n22, n19)
- n25 = Reshape[T=float, Tshape=int32](n24, n20)
-}
-)P";
- EXPECT_EQ(e2, DebugString(g.get()));
+ ExpandInlineFunctions(lib_.get(), g.get());
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
+ auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
+ auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
+ auto func0 = ops::Identity(s.WithOpName("Func/_0"), x);
+ auto func1 = ops::Identity(s.WithOpName("Func/_1"), y);
+ auto func2 = ops::Identity(s.WithOpName("Func/_2"), dz);
+ auto grad0_z = ops::Add(s.WithOpName("grad0/z"), func0, func1);
+ auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
+ auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
+ grad0_r, grad0_one);
+ auto grad0_l = ops::Sum(s.WithOpName("grad0/l"), grad0_z, grad0_indices);
+
+ NameAttrList sum;
+ sum.set_name("Sum");
+ (*sum.mutable_attr())["T"].set_type(DT_FLOAT);
+ (*sum.mutable_attr())["Tidx"].set_type(DT_INT32);
+ (*sum.mutable_attr())["keep_dims"].set_b(false);
+ auto grad0_func1 = ops::SymbolicGradient(
+ s.WithOpName("grad0/Func/_1"),
+ std::initializer_list<Input>{grad0_z, grad0_indices, func2},
+ {DT_FLOAT, DT_INT32}, sum);
+
+ auto grad0_func2 = ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_r);
+
+ NameAttrList add;
+ add.set_name("Add");
+ (*add.mutable_attr())["T"].set_type(DT_FLOAT);
+ auto grad0_func3 = ops::SymbolicGradient(
+ s.WithOpName("grad0/Func/_3"),
+ std::initializer_list<Input>{func0, func1, grad0_func1[0]},
+ {DT_FLOAT, DT_FLOAT}, add);
+
+ auto func3 = ops::Identity(s.WithOpName("Func/_3"), grad0_func3[0]);
+ auto func4 = ops::Identity(s.WithOpName("Func/_4"), grad0_func3[1]);
+ auto dx = ops::Identity(s.WithOpName("dx"), func3);
+ auto dy = ops::Identity(s.WithOpName("dy"), func4);
+ auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
+ auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
+
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
+
+ OptimizeGraph(lib_.get(), &g);
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
+ auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
+ auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
+ auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
+ auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
+ auto grad0_z = ops::Add(s.WithOpName("grad0/z"), x, y);
+ auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
+ auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
+ grad0_r, grad0_one);
+ auto i_shape =
+ ops::Shape(s.WithOpName("grad0/Func/_1/i_shape"), grad0_indices);
+ auto stitch_val = ops::Fill(s.WithOpName("grad0/Func/_1/stitch_val1"),
+ i_shape, grad0_one);
+ auto x_shape = ops::Shape(s.WithOpName("grad0/Func/_1/x_shape"), grad0_z);
+ auto y_shape = ops::DynamicStitch(
+ s.WithOpName("grad0/Func/_1/y_shape"),
+ std::initializer_list<Input>{grad0_indices, grad0_indices},
+ std::initializer_list<Input>{x_shape, stitch_val});
+ auto dy_reshaped =
+ ops::Reshape(s.WithOpName("grad0/Func/_1/dy_reshaped"), dz, y_shape);
+ auto tile_scaling =
+ ops::Div(s.WithOpName("grad0/Func/_1/tile_scaling"), x_shape, y_shape);
+ auto func1_dx =
+ ops::Tile(s.WithOpName("grad0/Func/_1/dx"), dy_reshaped, tile_scaling);
+
+ auto sx = ops::Shape(s.WithOpName("grad0/Func/_3/sx"), x);
+ auto sy = ops::Shape(s.WithOpName("grad0/Func/_3/sy"), y);
+ auto rx = ops::internal::BroadcastGradientArgs(
+ s.WithOpName("grad0/Func/_3/rx"), sx, sy);
+ auto sum_gx =
+ ops::Sum(s.WithOpName("grad0/Func/_3/sum_gx"), func1_dx, rx.r0);
+ auto sum_gy =
+ ops::Sum(s.WithOpName("grad0/Func/_3/sum_gy"), func1_dx, rx.r1);
+ auto dx = ops::Reshape(s.WithOpName("grad0/Func/_3/dx"), sum_gx, sx);
+ auto dy = ops::Reshape(s.WithOpName("grad0/Func/_3/dy"), sum_gy, sy);
+
+ auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
+ auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
+
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+ GraphDef actual;
+ g->ToGraphDef(&actual);
+ TF_EXPECT_GRAPH_EQ(expected, actual);
+ }
}
namespace {
bool DoNothing(Graph* g) { return false; }
-string Optimize(const std::function<bool(Graph* g)>& pass,
- const FunctionDef& fdef) {
+GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
+ const FunctionDef& fdef) {
InstantiationResult result;
InstantiateAttrValueMap empty;
TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
- Graph* g = new Graph(OpRegistry::Global());
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
- TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
- pass(g);
- Graph* g1 = new Graph(OpRegistry::Global());
- CopyGraph(*g, g1);
- delete g;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
+ pass(g.get());
+ std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
+ CopyGraph(*g, g1.get());
+ g = nullptr;
GraphDef gdef;
g1->ToGraphDef(&gdef);
- delete g1;
- return DebugString(gdef);
+ return gdef;
}
} // end namespace
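Usage sketch for the rewritten Optimize helper (DoNothing and the passes named here appear in the tests below):

  GraphDef unoptimized = Optimize(DoNothing, func);
  GraphDef optimized = Optimize(::tensorflow::RemoveIdentityNodes, func);
  TF_EXPECT_GRAPH_EQ(expected, optimized);  // expected built via a Scope, as above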
@@ -700,21 +886,25 @@ TEST(OptimizationTest, RemoveDeadNodes) {
{{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}},
// y = Add<T>(a, o)
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
- const char* e0 = R"S(
-(x:int32) -> (y:int32) {
- o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
- keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o)
- x1 = Add[T=int32](o, o)
- a = Square[T=int32](x)
- y = Add[T=int32](a, o)
- x2 = Mul[T=int32](a, x1)
- x3 = Mul[T=int32](x1, x2)
-}
-)S";
- EXPECT_EQ(Optimize(DoNothing, func), e0);
+
+ GraphDef expected;
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
+ auto o = ops::Const(s.WithOpName("o"), 1);
+ auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
+ auto x1 = ops::Add(s.WithOpName("x1"), o, o);
+ auto a = ops::Square(s.WithOpName("a"), x);
+ auto y = ops::Add(s.WithOpName("y"), a, o);
+ auto x2 = ops::Mul(s.WithOpName("x2"), a, x1);
+ auto x3 = ops::Mul(s.WithOpName("x3"), x1, x2);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ }
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
// TODO(zhifengc): Come up with another test case.
- EXPECT_EQ(Optimize(::tensorflow::RemoveDeadNodes, func), e0);
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(::tensorflow::RemoveDeadNodes, func));
}
TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
@@ -735,23 +925,19 @@ TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
{{"v_read"}, "Identity", {"v"}, {{"T", T}}},
// returns v + v
{{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});
- const char* e0 = R"S(
-() -> (ret:float) {
- v = VariableV2[container="", dtype=float, shape=[], shared_name=""]()
- v_read = Identity[T=float](v)
- ret = Add[T=float](v_read, v_read)
-}
-)S";
- EXPECT_EQ(Optimize(DoNothing, func), e0);
-
- const char* e1 = R"S(
-() -> (ret:float) {
- v = VariableV2[container="", dtype=float, shape=[], shared_name=""]()
- v_read = Identity[T=float](v)
- ret = Add[T=float](v_read, v_read)
-}
-)S";
- EXPECT_EQ(Optimize(::tensorflow::RemoveIdentityNodes, func), e1);
+
+ GraphDef expected;
+ {
+ Scope s = Scope::NewRootScope();
+ auto v = ops::Variable(s.WithOpName("v"), PartialTensorShape({}), DT_FLOAT);
+ auto v_read = ops::Identity(s.WithOpName("v_read"), v);
+ auto ret = ops::Add(s.WithOpName("ret"), v_read, v_read);
+ auto ret_retval = ops::_Retval(s.WithOpName("ret_RetVal"), ret, 0);
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ }
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
+ TF_EXPECT_GRAPH_EQ(expected,
+ Optimize(::tensorflow::RemoveIdentityNodes, func));
}
TEST(OptimizationTest, RemoveIdentityNodes) {
@@ -782,28 +968,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{"x3"}},
// y = Add<T>(a, o)
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
- const char* e0 = R"S(
-(x:int32) -> (y:int32) {
- o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
- a = Square[T=int32](x)
- y = Add[T=int32](a, o)
- x1 = Identity[T=int32](a)
- x2 = Identity[T=int32](x1)
- x3 = Identity[T=int32](x2)
- keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o) @ x3
-}
-)S";
- EXPECT_EQ(Optimize(DoNothing, func), e0);
-
- const char* e1 = R"S(
-(x:int32) -> (y:int32) {
- o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
- a = Square[T=int32](x)
- y = Add[T=int32](a, o)
- keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o) @ a
-}
-)S";
- EXPECT_EQ(Optimize(::tensorflow::RemoveIdentityNodes, func), e1);
+
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
+ auto o = ops::Const(s.WithOpName("o"), 1);
+ auto a = ops::Square(s.WithOpName("a"), x);
+ auto y = ops::Add(s.WithOpName("y"), a, o);
+ auto x1 = ops::Identity(s.WithOpName("x1"), a);
+ auto x2 = ops::Identity(s.WithOpName("x2"), x1);
+ auto x3 = ops::Identity(s.WithOpName("x3"), x2);
+ auto keep_me = ops::RandomUniform(
+ s.WithOpName("keep_me").WithControlDependencies(x3), {o}, DT_FLOAT);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
+ }
+
+ {
+ Scope s = Scope::NewRootScope();
+ auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
+ auto o = ops::Const(s.WithOpName("o"), 1);
+ auto a = ops::Square(s.WithOpName("a"), x);
+ auto y = ops::Add(s.WithOpName("y"), a, o);
+ auto keep_me = ops::RandomUniform(
+ s.WithOpName("keep_me").WithControlDependencies(a), {o}, DT_FLOAT);
+ auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected,
+ Optimize(::tensorflow::RemoveIdentityNodes, func));
+ }
}
TEST(OptimizationTest, RemoveListArrayConverter) {
@@ -840,49 +1036,63 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
// Return values
{{"o", "o:sum"}});
- const char* e0 = R"P(
-(i:float) -> (o:float) {
- zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- s = Split[T=float, num_split=4](zero, i)
- a = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s, s:1, s:2, s:3)
- r = Mul[T=float](a:2, a:3)
- l = Mul[T=float](a, a:1)
- x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
- o = AddN[N=2, T=float](x, x:1)
-}
-)P";
- EXPECT_EQ(Optimize(DoNothing, func), e0);
-
- const char* e1 = R"P(
-(i:float) -> (o:float) {
- zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- s = Split[T=float, num_split=4](zero, i)
- r = Mul[T=float](Func/_2, Func/_3)
- l = Mul[T=float](Func/_0, Func/_1)
- o = AddN[N=2, T=float](Func/_4, Func/_5)
- Func/_0 = Identity[T=float](s)
- Func/_1 = Identity[T=float](s:1)
- Func/_2 = Identity[T=float](s:2)
- Func/_3 = Identity[T=float](s:3)
- Func/_4 = Identity[T=float](l)
- Func/_5 = Identity[T=float](r)
-}
-)P";
- EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1);
-
- const char* e2 = R"P(
-(i:float) -> (o:float) {
- zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- s = Split[T=float, num_split=4](zero, i)
- r = Mul[T=float](s:2, s:3)
- l = Mul[T=float](s, s:1)
- o = AddN[N=2, T=float](l, r)
-}
-)P";
- auto remove_listarray_and_identity = [](Graph* g) {
- return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
- };
- EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e2);
+ {
+ Scope scope = Scope::NewRootScope();
+ auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
+ auto zero = ops::Const(scope.WithOpName("zero"), 0);
+ auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
+ auto a = ops::_ArrayToList(scope.WithOpName("a"), s.output,
+ {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
+ auto r = ops::Mul(scope.WithOpName("r"), a[2], a[3]);
+ auto l = ops::Mul(scope.WithOpName("l"), a[0], a[1]);
+ auto x = ops::_ListToArray(scope.WithOpName("x"),
+ std::initializer_list<Input>{l, r}, DT_FLOAT, 2);
+ auto o = ops::AddN(scope.WithOpName("o"), x.output);
+ auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
+ }
+
+ {
+ Scope scope = Scope::NewRootScope();
+ auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
+ auto zero = ops::Const(scope.WithOpName("zero"), 0);
+ auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
+ auto func_0 = ops::Identity(scope.WithOpName("Func/_0"), s[0]);
+ auto func_1 = ops::Identity(scope.WithOpName("Func/_1"), s[1]);
+ auto func_2 = ops::Identity(scope.WithOpName("Func/_2"), s[2]);
+ auto func_3 = ops::Identity(scope.WithOpName("Func/_3"), s[3]);
+ auto r = ops::Mul(scope.WithOpName("r"), func_2, func_3);
+ auto l = ops::Mul(scope.WithOpName("l"), func_0, func_1);
+ auto func_4 = ops::Identity(scope.WithOpName("Func/_4"), l);
+ auto func_5 = ops::Identity(scope.WithOpName("Func/_5"), r);
+ auto o = ops::AddN(scope.WithOpName("o"),
+ std::initializer_list<Input>{func_4, func_5});
+ auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
+ }
+
+ {
+ Scope scope = Scope::NewRootScope();
+ auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
+ auto zero = ops::Const(scope.WithOpName("zero"), 0);
+ auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
+ auto r = ops::Mul(scope.WithOpName("r"), s[2], s[3]);
+ auto l = ops::Mul(scope.WithOpName("l"), s[0], s[1]);
+ auto o =
+ ops::AddN(scope.WithOpName("o"), std::initializer_list<Input>{l, r});
+ auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected));
+
+ auto remove_listarray_and_identity = [](Graph* g) {
+ return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
+ };
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
+ }
}
TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
@@ -911,33 +1121,47 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{"x"}}},
{{"o", "o:sum"}});
- const char* e0 = R"P(
-(i:float) -> (o:float) {
- dummy = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- x = _ListToArray[N=2, T=float, Tin={float, float}](i, i) @ dummy
- o = AddN[N=2, T=float](x, x:1) @ x
-}
-)P";
- EXPECT_EQ(Optimize(DoNothing, func), e0);
-
- const char* e1 = R"P(
-(i:float) -> (o:float) {
- dummy = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- o = AddN[N=2, T=float](Func/_0, Func/_1) @ Func/_3
- Func/_0 = Identity[T=float](i) @ Func/_2
- Func/_1 = Identity[T=float](i) @ Func/_2
- Func/_2 = NoOp() @ dummy
- Func/_3 = NoOp() @ Func/_0, Func/_1
-}
-)P";
- EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1);
+ {
+ Scope s = Scope::NewRootScope();
+ auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
+ auto dummy = ops::Const(s.WithOpName("dummy"), 0);
+ auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
+ std::initializer_list<Input>{i, i}, DT_FLOAT, 2);
+ auto o =
+ ops::AddN(s.WithOpName("o").WithControlDependencies({x.output[0].op()}),
+ x.output);
+ auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
+ GraphDef expected;
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
+ }
+
+ GraphDef expected;
+ {
+ Scope s = Scope::NewRootScope();
+ auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
+ auto dummy = ops::Const(s.WithOpName("dummy"), 0);
+ auto func_2 =
+ ops::NoOp(s.WithOpName("Func/_2").WithControlDependencies(dummy));
+ auto func_0 = ops::Identity(
+ s.WithOpName("Func/_0").WithControlDependencies({func_2}), i);
+ auto func_1 = ops::Identity(
+ s.WithOpName("Func/_1").WithControlDependencies({func_2}), i);
+ auto func_3 = ops::NoOp(s.WithOpName("Func/_3").WithControlDependencies(
+ {func_0.output.op(), func_1.output.op()}));
+ auto o = ops::AddN(s.WithOpName("o").WithControlDependencies({func_3}),
+ std::initializer_list<Input>{func_0, func_1});
+ auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
+ TF_ASSERT_OK(s.ToGraphDef(&expected));
+ }
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
auto remove_listarray_and_identity = [](Graph* g) {
return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
};
// NOTE: We are not removing Identity nodes with any control
// dependencies yet.
- EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e1);
+ TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
} // end namespace tensorflow