aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-09-11 10:41:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 10:51:01 -0700
commit36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 (patch)
tree4f1671f78f5971b02dc2af66f57eabbf01005112 /tensorflow/core/framework
parent36d7b12357df667dcd427c070e21779ed83f4ec9 (diff)
[TF] Variant improvements.
1. Change Variant Decode to accept VariantTensorData (non-ref). This should allow some optimization in the future. In the meantime it means removing the variant.h include from tensor.h, since variant_encode_decode.h now relies on tensor.h and variant.h now relies on that. It also means we found a bunch of places where tensor.proto.h, variant.h, and mutex.h were being imported through tensor.h (along with a bunch of other crap); so now we directly import them in order to compile. 2. Move Variant registry to use TypeIndex instead of a TypeName string; this should speed up registry lookups. PiperOrigin-RevId: 212478896
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/allocator.cc9
-rw-r--r--tensorflow/core/framework/allocator.h11
-rw-r--r--tensorflow/core/framework/allocator_registry.h1
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc1
-rw-r--r--tensorflow/core/framework/tensor.h3
-rw-r--r--tensorflow/core/framework/tensor_test.cc1
-rw-r--r--tensorflow/core/framework/tensor_util.h1
-rw-r--r--tensorflow/core/framework/types.h3
-rw-r--r--tensorflow/core/framework/variant.cc25
-rw-r--r--tensorflow/core/framework/variant.h60
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h32
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc6
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc85
-rw-r--r--tensorflow/core/framework/variant_op_registry.h216
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
-rw-r--r--tensorflow/core/framework/variant_tensor_data.cc22
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h10
-rw-r--r--tensorflow/core/framework/variant_test.cc15
18 files changed, 302 insertions, 295 deletions
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c57b..2a7ee16a16 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe137..ded120b704 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,13 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Variant;
+
// Attributes for a single allocation call. Different calls to the same
// allocator could potentially have different allocation attributes.
struct AllocationAttributes {
@@ -228,13 +229,9 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
- virtual void RunVariantCtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
- }
+ virtual void RunVariantCtor(Variant* p, size_t n);
- virtual void RunVariantDtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
- }
+ virtual void RunVariantDtor(Variant* p, size_t n);
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282ce84..e907c52ba9 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a3994736c..4ffd732f8e 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5da3..696fd277cd 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@ namespace tensorflow {
class AllocationDescription;
class Allocator;
class OpKernelContext;
+class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
-class VariantTensorData;
+
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c196..9a78cdc91e 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/math/math_util.h"
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 4bda8f9eb8..a7cf600bab 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 15b1add2c1..2e96b05787 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@ limitations under the License.
namespace tensorflow {
+class Variant;
+
// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a507804b0..d43e3c72ec 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@ limitations under the License.
namespace tensorflow {
-bool Variant::TryDecode(Variant* out) const {
- const VariantTensorDataProto* p = get<VariantTensorDataProto>();
- if (p == nullptr) return false;
- VariantTensorData data(*p);
- return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+ if (!is_empty()) {
+ return value_->Decode(std::move(data));
+ }
+ return true;
}
template <>
@@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) {
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data) {
- data->FromProto(value);
+ data->FromConstProto(value);
}
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value) {
- data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+ data->ToProto(value);
return true;
}
@@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
}
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
- return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(*buf);
}
void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
if (variant_array[i].is_empty()) {
variant_array[i] = VariantTensorDataProto();
}
+ // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
+ // zero-copy operation that keeps a reference to the data in d?
string str(d->Data(sizes[i]), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
+ if (!variant_array[i].Decode(std::move(str))) return false;
if (!DecodeUnaryVariant(&variant_array[i])) {
LOG(ERROR) << "Could not decode variant with type_name: \""
<< variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index 52732801a0..10eabbc85f 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -23,7 +23,6 @@ limitations under the License.
#include <unordered_map>
#include <utility>
-#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"
@@ -38,17 +37,19 @@ string TypeNameVariant(const T& value);
template <typename T>
string DebugStringVariant(const T& value);
+// Allows for specializations of Variant Decoding. `data` may be modified in
+// the process of decoding to `value`.
template <typename T>
-void EncodeVariant(const T& value, VariantTensorData* data);
+bool DecodeVariant(VariantTensorData* data, T* value);
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
+bool DecodeVariant(string* buf, T* value);
template <typename T>
-void EncodeVariant(const T& value, string* buf);
+void EncodeVariant(const T& value, VariantTensorData* data);
template <typename T>
-bool DecodeVariant(const string& buf, T* value);
+void EncodeVariant(const T& value, string* buf);
// This is an implementation of a type-erased container that can store an
// object of any type. The implementation is very similar to std::any, but has
@@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value);
//
// string TypeName() const;
// void Encode(VariantTensorData* data) const;
-// void Decode(const VariantTensorData& data);
+// void Decode(VariantTensorData data);
//
// Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods.
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value);
// x.Encode(&serialized_f);
//
// Variant y = Foo(); // default constructed Foo.
-// y.Decode(&serialized_f);
+// y.Decode(std::move(serialized_f));
// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
//
//
@@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value);
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
// y_type_unknown.TypeId());
-// // Decode and get y_type_unknown; compare to value in x.
-// Foo f_decoded;
-// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-// EXPECT_EQ(f_decoded, f);
//
class Variant {
public:
@@ -241,12 +238,7 @@ class Variant {
}
// Deserialize `data` and update the stored object.
- bool Decode(const VariantTensorData& data) {
- if (!is_empty()) {
- return value_->Decode(data);
- }
- return true;
- }
+ bool Decode(VariantTensorData data);
// Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const {
@@ -254,31 +246,13 @@ class Variant {
value_->Encode(buf);
}
}
- bool Decode(const string& buf) {
+ bool Decode(string buf) {
if (!is_empty()) {
- return value_->Decode(buf);
+ return value_->Decode(std::move(buf));
}
return true;
}
- template <typename T>
- bool MaybeDecodeAndCopy(T* out) const {
- const T* ret = get<T>();
- if (ret != nullptr) {
- *out = std::move(*ret);
- return true;
- };
- Variant decoded = T();
- if (!TryDecode(&decoded)) return false;
- T* decoded_ret = decoded.get<T>();
- CHECK_NOTNULL(decoded_ret);
- *out = std::move(*decoded_ret);
- return true;
- }
-
- private:
- bool TryDecode(Variant* out) const;
-
private:
struct in_place_t {};
static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@ class Variant {
virtual string TypeName() const = 0;
virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0;
- virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual bool Decode(VariantTensorData data) = 0;
virtual void Encode(string* buf) const = 0;
- virtual bool Decode(const string& data) = 0;
+ virtual bool Decode(string data) = 0;
};
template <typename T>
@@ -325,15 +299,13 @@ class Variant {
EncodeVariant(value, data);
}
- bool Decode(const VariantTensorData& data) override {
- return DecodeVariant(data, &value);
+ bool Decode(VariantTensorData data) override {
+ return DecodeVariant(&data, &value);
}
void Encode(string* buf) const override { EncodeVariant(value, buf); }
- bool Decode(const string& buf) override {
- return DecodeVariant(buf, &value);
- }
+ bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
T value;
};
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index f155aa4892..5e08e5a7a6 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value,
// Specialization for POD type
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for tensorflow::Tensor
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for protobuf
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
@@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for other types
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
- return value->Decode(data);
+ return value->Decode(std::move(data));
}
template <typename C, typename = void>
@@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) {
}
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
- return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+ return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
@@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) {
}
template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
- if (!data.ParseFromString(buf)) return false;
- if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ if (!data.ParseFromString(*buf)) return false;
+ if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+ return false;
+ }
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
+
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects in to the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd559..daa744e877 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+ StoredTensorValue::CopyCPUToGPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
- "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+ StoredTensorValue::CopyGPUToCPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+ StoredTensorValue::CopyGPUToGPU);
REGISTER_OP("CreateTestVariant")
.Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1aee..ef5b240aea 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
- StringPiece type_name) {
- auto found = shape_fns.find(type_name);
+ const TypeIndex& type_index) {
+ auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
- VariantShapeFn* existing = GetShapeFn(type_name);
+ VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantShapeFn for type_name: " << type_name
- << " already registered";
- shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
- GetPersistentStringPiece(type_name), shape_fn));
+ << "Unary VariantShapeFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+ shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
- "No unary variant shape function found for Variant type_name: ",
- v.TypeName());
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}
@@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) {
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double);
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
UnaryVariantOpRegistry::GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name) {
- auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+ auto found = device_copy_fns.find(std::make_pair(direction, type_index));
if (found == device_copy_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
- AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+ AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
CHECK_EQ(existing, nullptr)
<< "UnaryVariantDeviceCopy for direction: " << direction
- << " and type_name: " << type_name << " already registered";
+ << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+ << " already registered";
device_copy_fns.insert(
- std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn>(
- std::make_pair(direction, GetPersistentStringPiece(type_name)),
- device_copy_fn));
+ std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+ device_copy_fn));
}
Status VariantDeviceCopy(
@@ -170,35 +167,34 @@ Status VariantDeviceCopy(
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
- from.TypeName());
+ from.TypeId());
if (device_copy_fn == nullptr) {
return errors::Internal(
"No unary variant device copy function found for direction: ",
- direction, " and Variant type_name: ", from.TypeName());
+ direction, " and Variant type_index: ",
+ port::MaybeAbiDemangle(from.TypeId().name()));
}
return (*device_copy_fn)(from, to, copy_fn);
}
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
- VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find({op, device, type_name});
+ VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+ auto found = unary_op_fns.find({op, device, type_index});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
- VariantUnaryOp op, const string& device, const string& type_name,
+ VariantUnaryOp op, const string& device, const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
- VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+ VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantUnaryOpFn for type_name: " << type_name
+ << "Unary VariantUnaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- unary_op_fn));
+ {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
}
namespace {
@@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
- DEVICE_CPU, T, TF_STR(T), \
+ DEVICE_CPU, T, \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name) {
- auto found = binary_op_fns.find({op, device, type_name});
+ const TypeIndex& type_index) {
+ auto found = binary_op_fns.find({op, device, type_index});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
- VariantBinaryOp op, const string& device, const string& type_name,
+ VariantBinaryOp op, const string& device, const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
- VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+ VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantBinaryOpFn for type_name: " << type_name
+ << "Unary VariantBinaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- add_fn));
+ {op, GetPersistentStringPiece(device), type_index}, add_fn));
}
namespace {
@@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
- T, TF_STR(T), \
- AddVariantPrimitiveType<T>);
+ T, AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index e6a2665a56..7eb37e859f 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -22,10 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
@@ -90,10 +94,11 @@ class UnaryVariantOpRegistry {
AsyncVariantDeviceCopyFn;
// Add a shape lookup function to the registry.
- void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+ void RegisterShapeFn(const TypeIndex& type_index,
+ const VariantShapeFn& shape_fn);
- // Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(StringPiece type_name);
+ // Returns nullptr if no shape function was found for the given TypeIndex.
+ VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@ class UnaryVariantOpRegistry {
// Add a copy-to-GPU function to the registry.
void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
- const string& type_name,
+ const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn);
// Returns nullptr if no copy function was found for the given
// TypeName and direction.
AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name);
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@ class UnaryVariantOpRegistry {
static std::unordered_set<string>* PersistentStringStorage();
private:
- std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
- std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
- decode_fns;
+ struct TypeIndexHash {
+ std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+ };
+
+ gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+ gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
// Map std::pair<Direction, type_name> to function.
struct PairHash {
template <typename Direction>
- std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+ std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
- ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+ ret = Hash64Combine(ret, std::get<1>(x).hash_code());
return ret;
}
- StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn, PairHash>
+ gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn, PairHash>
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry {
// and references therein
template <typename Op>
struct FuncTuple {
- FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
- : op_type_(op), device_(dev), typename_(tname){};
+ FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+ : op_type_(op), device_(dev), type_index_(type_index) {}
Op op_type_;
- StringPiece device_, typename_;
+ StringPiece device_;
+ TypeIndex type_index_;
};
// friend declaration for operator==
// needed for clang
@@ -184,11 +192,11 @@ class UnaryVariantOpRegistry {
struct TupleHash {
template <typename Op>
std::size_t operator()(
- const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ const std::tuple<Op, StringPiece, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
- ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+ ret = Hash64Combine(ret, std::get<2>(x).hash_code());
return ret;
}
@@ -197,14 +205,14 @@ class UnaryVariantOpRegistry {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(x.op_type_);
ret = Hash64Combine(ret, sp_hasher_(x.device_));
- ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ ret = Hash64Combine(ret, x.type_index_.hash_code());
return ret;
}
StringPieceHasher sp_hasher_;
};
- std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@ template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
- (lhs.typename_ == rhs.typename_);
+ (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
- UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
- if (a.TypeName() != b.TypeName()) {
+ if (a.TypeId() != b.TypeId()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
- "type names: '",
+ "type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
- UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
- UnaryVariantShapeRegistration(const string& type_name,
+ UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
- type_name,
- [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+ type_index,
+ [type_index_name, shape_fn](const Variant& v,
+ TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
- "VariantShapeFn: Could not access object, type_name: ",
- type_name);
+ "VariantShapeFn: Could not access object, type_index: ",
+ type_index_name);
}
return shape_fn(*t, s);
});
@@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration {
return false;
}
Variant decoded = T();
- VariantTensorData data(*t);
- if (!decoded.Decode(data)) {
+ VariantTensorData data(std::move(*t));
+ if (!decoded.Decode(std::move(data))) {
return false;
}
- *v = std::move(decoded);
+ std::swap(decoded, *v);
return true;
});
}
@@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration {
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const LocalVariantDeviceCopyFn& device_copy_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
- direction, type_name,
- [type_name, device_copy_fn](
+ direction, type_index,
+ [type_index_name, device_copy_fn](
const Variant& from, Variant* to,
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration {
*to = T();
if (from.get<T>() == nullptr) {
return errors::Internal(
- "VariantCopyToGPUFn: Could not access object, type_name: ",
- type_name);
+ "VariantCopyToGPUFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *from.get<T>();
T* t_out = to->get<T>();
@@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration {
public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantUnaryOpFn& unary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
- op, device, type_name,
- [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
- Variant* v_out) -> Status {
+ op, device, type_index,
+ [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) -> Status {
DCHECK_NE(v_out, nullptr);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantUnaryOpFn: Could not access object, type_name: ",
- type_name);
+ "VariantUnaryOpFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration {
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantBinaryOpFn& binary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
- op, device, type_name,
- [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
- const Variant& b, Variant* out) -> Status {
+ op, device, type_index,
+ [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+ const Variant& b,
+ Variant* out) -> Status {
DCHECK_NE(out, nullptr);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'a', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+ type_index_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'b', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+ type_index_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration {
// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
- shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
- shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+ shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
- register_unary_variant_op_shape_registration_fn_##ctr(type_name, \
+ register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)
// Register a unary decode variant function for the given type.
@@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration {
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- T, direction, type_name, device_copy_fn) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
+ device_copy_fn) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- ctr, T, direction, type_name, device_copy_fn) \
+ ctr, T, direction, type_index, device_copy_fn) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn)
+ ctr, T, direction, type_index, device_copy_fn)
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantDeviceCopyRegistration<T> \
- register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
- device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+ ctr, T, direction, type_index, device_copy_fn) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantDeviceCopyRegistration<T> \
+ register_unary_variant_op_device_copy_fn_##ctr( \
+ direction, type_index, device_copy_fn)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
- unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
- unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_index, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+ type_index, unary_op_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, unary_op_function) \
+ ctr, op, device, T, type_index, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
- binary_op_function) \
- REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, binary_op_function) \
+ ctr, op, device, T, type_index, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function)
+ ctr, op, device, T, type_index, binary_op_function)
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantBinaryOpRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_index, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
binary_op_function)
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc114..3e67e4a864 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@ namespace tensorflow {
VariantTensorData::VariantTensorData() {}
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
- FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+ FromProto(std::move(proto));
}
VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
}
}
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+ // TODO(ebrevdo): Do this lazily.
+ set_type_name(proto.type_name());
+ set_metadata(proto.metadata());
+ for (const auto& tensor : proto.tensors()) {
+ Tensor tmp;
+ if (!tmp.FromProto(tensor)) return false;
+ tensors_.push_back(tmp);
+ }
+ return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
set_type_name(proto.type_name());
set_metadata(proto.metadata());
for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) {
return proto.SerializeToString(buf);
}
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
VariantTensorDataProto proto;
const bool status = proto.ParseFromString(s);
- if (status) FromProto(proto);
+ if (status) FromProto(std::move(proto));
return status;
}
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 7500e77d43..8a240ee1e3 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class VariantTensorDataProto;
-class Tensor;
// The serialization format for Variant objects. Objects with references to
// other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@ class Tensor;
class VariantTensorData {
public:
VariantTensorData();
- VariantTensorData(const VariantTensorDataProto& proto);
+ VariantTensorData(VariantTensorDataProto proto);
~VariantTensorData();
// Name of the type of objects being serialized.
@@ -68,12 +68,14 @@ class VariantTensorData {
// Conversion to and from VariantTensorDataProto
void ToProto(VariantTensorDataProto* proto) const;
- bool FromProto(const VariantTensorDataProto& proto);
+ // This allows optimizations via std::move.
+ bool FromProto(VariantTensorDataProto proto);
+ bool FromConstProto(const VariantTensorDataProto& proto);
// Serialization via VariantTensorDataProto
string SerializeAsString() const;
bool SerializeToString(string* buf);
- bool ParseFromString(const string& s);
+ bool ParseFromString(string s);
string DebugString() const;
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47d15..08d09de7b8 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) {
struct TensorList {
void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
- bool Decode(const VariantTensorData& data) {
- vec = data.tensors_;
+ bool Decode(VariantTensorData data) {
+ vec = std::move(data.tensors_);
return true;
}
@@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) {
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(serialized);
+ y.Decode(std::move(serialized));
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) {
EXPECT_EQ(y_unknown.DebugString(),
strings::StrCat(
"Variant<type: TensorList value: ", data.DebugString(), ">"));
-
- TensorList unknown_decoded_vec;
- EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
- }
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
- }
}
TEST(VariantTest, VariantArray) {