diff options
Diffstat (limited to 'tensorflow/core/framework/variant.cc')
-rw-r--r-- | tensorflow/core/framework/variant.cc | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 5a507804b0..d43e3c72ec 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,11 +23,11 @@ limitations under the License. namespace tensorflow { -bool Variant::TryDecode(Variant* out) const { - const VariantTensorDataProto* p = get<VariantTensorDataProto>(); - if (p == nullptr) return false; - VariantTensorData data(*p); - return out->Decode(data); +bool Variant::Decode(VariantTensorData data) { + if (!is_empty()) { + return value_->Decode(std::move(data)); + } + return true; } template <> @@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) { template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { - data->FromProto(value); + data->FromConstProto(value); } template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value) { - data.ToProto(value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { + data->ToProto(value); return true; } @@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) { } template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { - return value->ParseFromString(buf); +bool DecodeVariant(string* buf, VariantTensorDataProto* value) { + return value->ParseFromString(*buf); } void EncodeVariantList(const Variant* variant_array, int64 n, @@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, if (variant_array[i].is_empty()) { variant_array[i] = VariantTensorDataProto(); } + // TODO(ebrevdo): Replace with StringPiece? Any way to make this a + // zero-copy operation that keeps a reference to the data in d? string str(d->Data(sizes[i]), sizes[i]); - if (!variant_array[i].Decode(str)) return false; + if (!variant_array[i].Decode(std::move(str))) return false; if (!DecodeUnaryVariant(&variant_array[i])) { LOG(ERROR) << "Could not decode variant with type_name: \"" << variant_array[i].TypeName() |