diff options
Diffstat (limited to 'tensorflow/core/graph/testlib.cc')
-rw-r--r-- | tensorflow/core/graph/testlib.cc | 299 |
1 files changed, 299 insertions, 0 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc new file mode 100644 index 0000000000..e49d5e819a --- /dev/null +++ b/tensorflow/core/graph/testlib.cc @@ -0,0 +1,299 @@ +#include "tensorflow/core/graph/testlib.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace test { +namespace graph { + +Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") + .Input(input, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast<int64>(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Recv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver) { + Node* ret; + DataType dtype; + CHECK(DataTypeFromString(type, &dtype)); + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") + .Attr("tensor_type", dtype) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast<int64>(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor, const string& name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(name, "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") + .Attr("dtype", dtype) + .Attr("shape", shape) + .Finalize(g, &ret)); + return ret; +} + +Node* Assign(Graph* g, Node* var, Node* val) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") + .Input(var) + .Input(val) + .Attr("use_locking", true) + .Finalize(g, &ret)); + return ret; +} + +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce) + .Input(data) + .Input(axes) + .Attr("keep_dims", keep_dims) + .Finalize(g, &ret)); + return ret; +} + +Node* QuantizeToUINT8(Graph* g, Node* data) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") + .Input(data) + .Attr("T", DT_QUINT8) + .Attr("max_range", 1.0f) + .Attr("min_range", -1.0f) + .Finalize(g, &ret)); + return ret; +} + +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") + .Input(in0) + .Input(in1) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, + DataType dtype) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), op) + .Input(input) + .Attr("dtype", dtype) + .Attr("seed", 0) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomUniform(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomUniform", g, input, dtype); +} + +Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); +} + +Node* RandomParameters(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomParameters", g, input, dtype); +} + +Node* Unary(Graph* g, const string& func, Node* input, int index) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret)); + return ret; +} + +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), func) + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { + Node* ret; + auto b = NodeBuilder(g->NewName("n"), func); + for (Node* n : ins) b = b.Input(n); + TF_CHECK_OK(b.Finalize(g, &ret)); + return ret; +} + +Node* Identity(Graph* g, Node* input, int index) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") + .Input(input, index) + .Finalize(g, &ret)); + return ret; +} + +Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } + +Node* Error(Graph* g, Node* input, const string& errmsg) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") + .Input(input) + .Attr("message", errmsg) + .Finalize(g, &ret)); + return ret; +} + +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { + DCHECK(out_type != invalid_type); + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") + .Attr("TIn", out_type) + .Attr("TOut", invalid_type) + .Finalize(g, &ret)); + return ret; +} + +Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") + .Input(input) + .Attr("micros", delay_micros.value()) + .Finalize(g, &ret)); + return ret; +} + +Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") + .ControlInputs(control_inputs) + .Finalize(g, &ret)); + return ret; +} + +Node* Switch(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Enter(Graph* g, Node* input, const string& frame_name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") + .Input(input) + .Attr("frame_name", frame_name) + .Finalize(g, &ret)); + return ret; +} + +Node* Exit(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") + .Input({in0, in1}) + .Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { + std::vector<NodeBuilder::NodeOut> inputs; + inputs.reserve(remaining_in.size() + 1); + inputs.emplace_back(in0); + for (const string& in_name : remaining_in) { + inputs.emplace_back(in_name, 0, inputs[0].dt); + } + + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); + return ret; +} + +Node* Next(Graph* g, const string& name, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* LoopCond(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Less(Graph* g, Node* in0, Node* in1) { + return Binary(g, "Less", in0, in1); +} + +Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") + .Input(c) + .Input(inx) + .Input(iny) + .Finalize(g, &ret)); + return ret; +} + +Node* Cast(Graph* g, Node* in, DataType dst) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") + .Input(in) + .Attr("DstT", dst) + .Finalize(g, &ret)); + return ret; +} + +void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow |