diff options
author | 2018-08-30 12:13:54 -0700 | |
---|---|---|
committer | 2018-08-30 12:17:57 -0700 | |
commit | f9bd3bc45a57d218c7fe3e970c4f953da8136dc6 (patch) | |
tree | b488a89c5f066d2622334be1a7bcb0675d1b8759 /tensorflow/core/framework | |
parent | c44566cdd632e4a9b030244cbc36965cb0ee21c0 (diff) |
Add lowering pass for functional While op.
This will allow the functional tf.while_loop proposed in https://github.com/tensorflow/community/pull/13 to achieve feature parity with the current implementation.
Lowering is performed only when the "_lower_using_switch_merge" attr is set to True.
PiperOrigin-RevId: 210956432
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/function_testlib.cc | 56 | ||||
-rw-r--r-- | tensorflow/core/framework/function_testlib.h | 9 |
2 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 6e38256ba8..46b169dddc 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -219,6 +219,62 @@ FunctionDef InvalidControlFlow() { {{"o", "add:z"}}); } +FunctionDef LessThanOrEqualToN(int64 N) { + const Tensor kN = test::AsScalar<int64>(N); + return FDH::Define( + // Name + "LessThanOrEqualToN", + // Args + {"x: T"}, + // Return values + {"z: bool"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}}, + {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XPlusOneXTimesY() { + const Tensor kOne = test::AsScalar<int64>(1); + return FDH::Define( + // Name + "XPlusOneXTimesY", + // Args + {"x: T", "y: T"}, + // Return values + {"s: T", "t: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}}, + {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}}, + {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}); +} + +FunctionDef XYXLessThanOrEqualToN(int64 N) { + const Tensor kN = test::AsScalar<int64>(N); + return FDH::Define( + // Name + "XYXLessThanOrEqualToN", + // Args + {"x: T", "y: T"}, + // Return values + {"z: bool"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}}, + {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}}, + }); +} + void FunctionTestSchedClosure(std::function<void()> fn) { static thread::ThreadPool* w = new thread::ThreadPool(Env::Default(), "Test", 8); diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index af08d296b2..6d6476b936 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -87,6 +87,15 @@ FunctionDef Swap(); // Contains malformed control flow which can't be run by the executor. FunctionDef InvalidControlFlow(); +// x:T -> x <= N. +FunctionDef LessThanOrEqualToN(int64 N); + +// x:T, y:T -> x+1, x*y +FunctionDef XPlusOneXTimesY(); + +// x:T, y:T -> x <= N +FunctionDef XYXLessThanOrEqualToN(int64 N); + void FunctionTestSchedClosure(std::function<void()> fn); } // end namespace function |