diff options
author | Yao Zhang <yaozhang@google.com> | 2017-03-08 13:31:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-08 13:51:36 -0800 |
commit | 3e1ad1c84a48a320cb3de8f9ddd53d3923c90564 (patch) | |
tree | c74fc865bd2b26bd23df25c8869a46f8c8ddf6c6 /tensorflow/core/grappler/grappler_item.cc | |
parent | e62d0730d376bcad61381c99cc2899e013922171 (diff) |
Add a function to initialize a tensor and a utility function to add a prefix to the node name.
Change: 149576194
Diffstat (limited to 'tensorflow/core/grappler/grappler_item.cc')
-rw-r--r-- | tensorflow/core/grappler/grappler_item.cc | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 907ef2096a..c4691ad58f 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -31,6 +31,28 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { +void InitializeTensor(DataType type, Tensor* tensor) { + const int period = 7; + if (type == DT_FLOAT) { + auto flat = tensor->flat<float>(); + // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ... + for (int i = 0; i < flat.size(); i++) { + flat(i) = static_cast<float>(i % period) / 10.0f; + } + } else if (type == DT_INT64) { + auto flat = tensor->flat<int64>(); + // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ... + for (int i = 0; i < flat.size(); i++) { + flat(i) = i % period; + } + } else { + memset(const_cast<char*>(tensor->tensor_data().data()), 0, + tensor->tensor_data().size()); + } +} +} // namespace + // static std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef( const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) { @@ -125,11 +147,7 @@ std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef( } } Tensor fake_input(type, shape); - // TODO(bsteiner): figure out a better way to initialize the feeds, for - // example by recording a sample of the fed inputs in mldash when running - // the graph. - memset(const_cast<char*>(fake_input.tensor_data().data()), 0, - fake_input.tensor_data().size()); + InitializeTensor(type, &fake_input); new_item->feed.emplace_back(node.name(), fake_input); } |