aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/memory_types.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2016-04-02 11:55:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-02 13:02:16 -0700
commit0cd025b21cbf86efe4f8e09ad2677d15f45a6ff0 (patch)
tree465b64e05d87dac3ed590848d9ee254f3d8b4099 /tensorflow/core/common_runtime/memory_types.cc
parent4504be5df9383f744cae10b2dcde8f7ecf1fae7b (diff)
Enable constant folding in L0 optimization level.
Change: 118861866
Diffstat (limited to 'tensorflow/core/common_runtime/memory_types.cc')
-rw-r--r--tensorflow/core/common_runtime/memory_types.cc15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc
index d8585859b8..957a124934 100644
--- a/tensorflow/core/common_runtime/memory_types.cc
+++ b/tensorflow/core/common_runtime/memory_types.cc
@@ -82,4 +82,19 @@ Status ValidateMemoryTypes(DeviceType device_type, const Graph* g) {
return Status::OK();
}
+Status MemoryTypeForOutput(DeviceType device_type, const Graph* g,
+ const Node* n, int index, MemoryType* memory_type) {
+ MemoryTypeVector inp_mvec;
+ MemoryTypeVector out_mvec;
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
+ &inp_mvec, &out_mvec));
+ if (out_mvec.size() <= index) {
+ return errors::Internal("Trying to get the memory type for ", index,
+ "'th output of node ", n->DebugString(),
+ " that has only ", out_mvec.size(), " outputs");
+ }
+ *memory_type = out_mvec[index];
+ return Status::OK();
+}
+
} // end namespace tensorflow