aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-08-30 12:13:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 12:17:57 -0700
commitf9bd3bc45a57d218c7fe3e970c4f953da8136dc6 (patch)
treeb488a89c5f066d2622334be1a7bcb0675d1b8759 /tensorflow/core/framework
parentc44566cdd632e4a9b030244cbc36965cb0ee21c0 (diff)
Add lowering pass for functional While op.
This will allow the functional tf.while_loop proposed in https://github.com/tensorflow/community/pull/13 to achieve feature parity with the current implementation. Lowering is performed only when the "_lower_using_switch_merge" attr is set to True. PiperOrigin-RevId: 210956432
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/function_testlib.cc56
-rw-r--r--tensorflow/core/framework/function_testlib.h9
2 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 6e38256ba8..46b169dddc 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -219,6 +219,62 @@ FunctionDef InvalidControlFlow() {
{{"o", "add:z"}});
}
+FunctionDef LessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "LessThanOrEqualToN",
+ // Args
+ {"x: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XPlusOneXTimesY() {
+ const Tensor kOne = test::AsScalar<int64>(1);
+ return FDH::Define(
+ // Name
+ "XPlusOneXTimesY",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"s: T", "t: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
+ {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
+ {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
+}
+
+FunctionDef XYXLessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "XYXLessThanOrEqualToN",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
+ });
+}
+
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index af08d296b2..6d6476b936 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -87,6 +87,15 @@ FunctionDef Swap();
// Contains malformed control flow which can't be run by the executor.
FunctionDef InvalidControlFlow();
+// x:T -> x <= N.
+FunctionDef LessThanOrEqualToN(int64 N);
+
+// x:T, y:T -> x+1, x*y
+FunctionDef XPlusOneXTimesY();
+
+// x:T, y:T -> x <= N
+FunctionDef XYXLessThanOrEqualToN(int64 N);
+
void FunctionTestSchedClosure(std::function<void()> fn);
} // end namespace function