aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/grappler_item.cc
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-03-08 13:31:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-08 13:51:36 -0800
commit3e1ad1c84a48a320cb3de8f9ddd53d3923c90564 (patch)
treec74fc865bd2b26bd23df25c8869a46f8c8ddf6c6 /tensorflow/core/grappler/grappler_item.cc
parente62d0730d376bcad61381c99cc2899e013922171 (diff)
Add a function to initialize a tensor and a utility function to add a prefix to the node name.
Change: 149576194
Diffstat (limited to 'tensorflow/core/grappler/grappler_item.cc')
-rw-r--r--tensorflow/core/grappler/grappler_item.cc28
1 files changed, 23 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 907ef2096a..c4691ad58f 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -31,6 +31,28 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+void InitializeTensor(DataType type, Tensor* tensor) {
+ const int period = 7;
+ if (type == DT_FLOAT) {
+ auto flat = tensor->flat<float>();
+ // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
+ for (int i = 0; i < flat.size(); i++) {
+ flat(i) = static_cast<float>(i % period) / 10.0f;
+ }
+ } else if (type == DT_INT64) {
+ auto flat = tensor->flat<int64>();
+ // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
+ for (int i = 0; i < flat.size(); i++) {
+ flat(i) = i % period;
+ }
+ } else {
+ memset(const_cast<char*>(tensor->tensor_data().data()), 0,
+ tensor->tensor_data().size());
+ }
+}
+} // namespace
+
// static
std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef(
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
@@ -125,11 +147,7 @@ std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef(
}
}
Tensor fake_input(type, shape);
- // TODO(bsteiner): figure out a better way to initialize the feeds, for
- // example by recording a sample of the fed inputs in mldash when running
- // the graph.
- memset(const_cast<char*>(fake_input.tensor_data().data()), 0,
- fake_input.tensor_data().size());
+ InitializeTensor(type, &fake_input);
new_item->feed.emplace_back(node.name(), fake_input);
}