diff options
Diffstat (limited to 'tensorflow/core/framework/memory_types.cc')
-rw-r--r-- | tensorflow/core/framework/memory_types.cc | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 270118bb67..6dff6fe654 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -60,13 +60,18 @@ void MemoryTypesHelper(const NameRangeMap& name_map, host_memory_args->resize(keep); } +bool IsFunctionCallOp(const string& op_type) { + return op_type == "SymbolicGradient" || op_type == "PartitionedCall" || + op_type == "StatefulPartitionedCall"; +} + +} // namespace + MemoryType MTypeFromDType(const DataType dtype) { return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY : DEVICE_MEMORY; } -} // namespace - Status MemoryTypesForNode(const OpRegistryInterface* op_registry, const DeviceType& device_type, const NodeDef& ndef, MemoryTypeVector* inp_mtypes, @@ -94,7 +99,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, // TODO(zhifengc,phawkins): We should do type inference over function bodies // to derive the correct input/output memory types. We should also split // host-memory and non host-memory arguments into separate type lists. - if (!status.ok() || ndef.op() == "SymbolicGradient") { + if (!status.ok() || IsFunctionCallOp(ndef.op())) { for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); return Status::OK(); |