aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/testlib.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/testlib.cc')
-rw-r--r--tensorflow/core/graph/testlib.cc299
1 files changed, 299 insertions, 0 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
new file mode 100644
index 0000000000..e49d5e819a
--- /dev/null
+++ b/tensorflow/core/graph/testlib.cc
@@ -0,0 +1,299 @@
+#include "tensorflow/core/graph/testlib.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace test {
+namespace graph {
+
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
+ const uint64 sender_incarnation, const string& receiver) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
+ .Input(input, 0)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Recv(Graph* g, const string& tensor, const string& type,
+ const string& sender, const uint64 sender_incarnation,
+ const string& receiver) {
+ Node* ret;
+ DataType dtype;
+ CHECK(DataTypeFromString(type, &dtype));
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
+ .Attr("tensor_type", dtype)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(name, "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
+ .Attr("dtype", dtype)
+ .Attr("shape", shape)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Assign(Graph* g, Node* var, Node* val) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
+ .Input(var)
+ .Input(val)
+ .Attr("use_locking", true)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
+ bool keep_dims) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce)
+ .Input(data)
+ .Input(axes)
+ .Attr("keep_dims", keep_dims)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* QuantizeToUINT8(Graph* g, Node* data) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
+ .Input(data)
+ .Attr("T", DT_QUINT8)
+ .Attr("max_range", 1.0f)
+ .Attr("min_range", -1.0f)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
+ .Input(in0)
+ .Input(in1)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
+ DataType dtype) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), op)
+ .Input(input)
+ .Attr("dtype", dtype)
+ .Attr("seed", 0)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomUniform", g, input, dtype);
+}
+
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
+}
+
+Node* RandomParameters(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomParameters", g, input, dtype);
+}
+
+Node* Unary(Graph* g, const string& func, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), func)
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
+ Node* ret;
+ auto b = NodeBuilder(g->NewName("n"), func);
+ for (Node* n : ins) b = b.Input(n);
+ TF_CHECK_OK(b.Finalize(g, &ret));
+ return ret;
+}
+
+Node* Identity(Graph* g, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
+ .Input(input, index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
+
+Node* Error(Graph* g, Node* input, const string& errmsg) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
+ .Input(input)
+ .Attr("message", errmsg)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
+ DCHECK(out_type != invalid_type);
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
+ .Attr("TIn", out_type)
+ .Attr("TOut", invalid_type)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
+ .Input(input)
+ .Attr("micros", delay_micros.value())
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
+ .ControlInputs(control_inputs)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Switch(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Enter(Graph* g, Node* input, const string& frame_name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
+ .Input(input)
+ .Attr("frame_name", frame_name)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Exit(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
+ .Input({in0, in1})
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
+ std::vector<NodeBuilder::NodeOut> inputs;
+ inputs.reserve(remaining_in.size() + 1);
+ inputs.emplace_back(in0);
+ for (const string& in_name : remaining_in) {
+ inputs.emplace_back(in_name, 0, inputs[0].dt);
+ }
+
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Next(Graph* g, const string& name, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* LoopCond(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Less(Graph* g, Node* in0, Node* in1) {
+ return Binary(g, "Less", in0, in1);
+}
+
+Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
+ .Input(c)
+ .Input(inx)
+ .Input(iny)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Cast(Graph* g, Node* in, DataType dst) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
+ .Input(in)
+ .Attr("DstT", dst)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
+
+} // end namespace graph
+} // end namespace test
+} // end namespace tensorflow