aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/fake_input.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/framework/fake_input.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/framework/fake_input.cc')
-rw-r--r--tensorflow/core/framework/fake_input.cc214
1 files changed, 214 insertions, 0 deletions
diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc
new file mode 100644
index 0000000000..493c35e05f
--- /dev/null
+++ b/tensorflow/core/framework/fake_input.cc
@@ -0,0 +1,214 @@
+#include "tensorflow/core/framework/fake_input.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace {
+
+class FakeInputImpl {
+ public:
+ FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
+ NodeDefBuilder* builder);
+ void SetN(int n);
+ void SetDataType(DataType dt);
+ void SetTypeList(DataTypeSlice dts);
+ Status AddInputToBuilder();
+
+ private:
+ static string FakeNodeName(int in_index);
+ Status GetN(int* n) const;
+ Status GetDataType(DataType* dt) const;
+ void NSources(int n, DataType dt) const;
+ void SourceList(DataTypeSlice dts) const;
+
+ const OpDef* const op_def_;
+ const OpDef::ArgDef* const arg_;
+ const string in_node_;
+ const NodeDef* const node_def_;
+ NodeDefBuilder* const builder_;
+
+ bool n_specified_;
+ int n_;
+ bool dt_specified_;
+ DataType dt_;
+ bool dts_specified_;
+ DataTypeSlice dts_;
+};
+
+FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
+ const NodeDef* node_def, NodeDefBuilder* builder)
+ : op_def_(op_def),
+ arg_(&op_def->input_arg(in_index)),
+ in_node_(FakeNodeName(in_index)),
+ node_def_(node_def),
+ builder_(builder),
+ n_specified_(false),
+ dt_specified_(false),
+ dts_specified_(false) {}
+
+void FakeInputImpl::SetN(int n) {
+ n_specified_ = true;
+ n_ = n;
+}
+
+void FakeInputImpl::SetDataType(DataType dt) {
+ dt_specified_ = true;
+ dt_ = dt;
+}
+
+void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
+ dts_specified_ = true;
+ dts_ = dts;
+}
+
+Status FakeInputImpl::AddInputToBuilder() {
+ if (dts_specified_) {
+ SourceList(dts_);
+
+ } else if (n_specified_ || !arg_->number_attr().empty()) {
+ int n;
+ TF_RETURN_IF_ERROR(GetN(&n));
+
+ DataType dt;
+ if (n > 0) {
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ } else {
+ dt = DT_FLOAT;
+ }
+
+ NSources(n, dt);
+ } else {
+ if (!dt_specified_ && !arg_->type_list_attr().empty()) {
+ DataTypeVector dts;
+ Status status =
+ GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
+ if (!status.ok()) {
+ return errors::InvalidArgument(
+ "Could not infer list of types for input '", arg_->name(), "': ",
+ status.error_message());
+ }
+ SourceList(dts);
+ return Status::OK();
+ }
+
+ DataType dt;
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ builder_->Input(in_node_, 0, dt);
+ }
+ return Status::OK();
+}
+
+// static
+string FakeInputImpl::FakeNodeName(int in_index) {
+ char c = 'a' + (in_index % 26);
+ return string(&c, 1);
+}
+
+Status FakeInputImpl::GetN(int* n) const {
+ if (n_specified_) {
+ *n = n_;
+ } else {
+ Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer length of input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ }
+ return Status::OK();
+}
+
+Status FakeInputImpl::GetDataType(DataType* dt) const {
+ if (dt_specified_) {
+ *dt = dt_;
+ } else if (arg_->type() != DT_INVALID) {
+ *dt = arg_->type();
+ } else if (!arg_->type_attr().empty()) {
+ Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer type for input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ } else {
+ return errors::InvalidArgument("No type or type_attr field in arg '",
+ arg_->name(), "'");
+ }
+ return Status::OK();
+}
+
+void FakeInputImpl::NSources(int n, DataType dt) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ srcs.emplace_back(in_node_, i, dt);
+ }
+ builder_->Input(srcs);
+}
+
+void FakeInputImpl::SourceList(DataTypeSlice dts) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(dts.size());
+ for (size_t i = 0; i < dts.size(); ++i) {
+ srcs.emplace_back(in_node_, i, dts[i]);
+ }
+ builder_->Input(srcs);
+}
+
+} // namespace
+
+// Public interface ------------------------------------------------------------
+
+FakeInputFunctor FakeInput() {
+ return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataType dt) {
+ return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n) {
+ return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n, DataType dt) {
+ return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataTypeSlice dts) {
+ // Make a copy to ensure the data will still be around when the lambda is
+ // called.
+ DataTypeVector dtv(dts.begin(), dts.end());
+ return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetTypeList(dtv);
+ return impl.AddInputToBuilder();
+ };
+}
+
+} // namespace tensorflow