diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-09-11 10:41:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 10:51:01 -0700 |
commit | 36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 (patch) | |
tree | 4f1671f78f5971b02dc2af66f57eabbf01005112 /tensorflow/core/framework/variant_op_registry.cc | |
parent | 36d7b12357df667dcd427c070e21779ed83f4ec9 (diff) |
[TF] Variant improvements.
1. Change Variant Decode to accept VariantTensorData (non-ref).
This should allow some optimization in the future.
In the meantime it means removing the variant.h include from tensor.h, since
variant_encode_decode.h now relies on tensor.h and variant.h now relies on that.
It also means we found a bunch of places where tensor.proto.h, variant.h, and
mutex.h were being imported through tensor.h (along with a bunch of other crap);
so now we directly import them in order to compile.
2. Move Variant registry to use TypeIndex instead of a TypeName string; this should
speed up registry lookups.
PiperOrigin-RevId: 212478896
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.cc')
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.cc | 85 |
1 files changed, 39 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index ee07db1aee..ef5b240aea 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - StringPiece type_name) { - auto found = shape_fns.find(type_name); + const TypeIndex& type_index) { + auto found = shape_fns.find(type_index); if (found == shape_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, +void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index, const VariantShapeFn& shape_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; - VariantShapeFn* existing = GetShapeFn(type_name); + VariantShapeFn* existing = GetShapeFn(type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantShapeFn for type_name: " << type_name - << " already registered"; - shape_fns.insert(std::pair<StringPiece, VariantShapeFn>( - GetPersistentStringPiece(type_name), shape_fn)); + << "Unary VariantShapeFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered"; + shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { CHECK_EQ(variant_tensor.dims(), 0); const Variant& v = variant_tensor.scalar<Variant>()(); UnaryVariantOpRegistry::VariantShapeFn* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId()); if (shape_fn == nullptr) { return errors::Internal( - "No unary variant shape function found for Variant type_name: ", - v.TypeName()); + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(v.TypeId().name())); } return (*shape_fn)(v, shape); } @@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) { } // namespace #define REGISTER_VARIANT_SHAPE_TYPE(T) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>); + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>); // No encode/shape registered for std::complex<> and Eigen::half // objects yet. @@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double); UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* UnaryVariantOpRegistry::GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name) { - auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); if (found == device_copy_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterDeviceCopyFn( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; - AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); CHECK_EQ(existing, nullptr) << "UnaryVariantDeviceCopy for direction: " << direction - << " and type_name: " << type_name << " already registered"; + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; device_copy_fns.insert( - std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn>( - std::make_pair(direction, GetPersistentStringPiece(type_name)), - device_copy_fn)); + std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index), + device_copy_fn)); } Status VariantDeviceCopy( @@ -170,35 +167,34 @@ Status VariantDeviceCopy( const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, - from.TypeName()); + from.TypeId()); if (device_copy_fn == nullptr) { return errors::Internal( "No unary variant device copy function found for direction: ", - direction, " and Variant type_name: ", from.TypeName()); + direction, " and Variant type_index: ", + port::MaybeAbiDemangle(from.TypeId().name())); } return (*device_copy_fn)(from, to, copy_fn); } // Special casing UnaryOpFn per op and per device. UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( - VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find({op, device, type_name}); + VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterUnaryOpFn( - VariantUnaryOp op, const string& device, const string& type_name, + VariantUnaryOp op, const string& device, const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; - VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantUnaryOpFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - unary_op_fn)); + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); } namespace { @@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ - DEVICE_CPU, T, TF_STR(T), \ + DEVICE_CPU, T, \ ZerosLikeVariantPrimitiveType<T>); // No zeros_like registered for std::complex<> or Eigen::half objects yet. @@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); // Special casing BinaryOpFn per op and per device. UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name) { - auto found = binary_op_fns.find({op, device, type_name}); + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterBinaryOpFn( - VariantBinaryOp op, const string& device, const string& type_name, + VariantBinaryOp op, const string& device, const TypeIndex& type_index, const VariantBinaryOpFn& add_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; - VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantBinaryOpFn for type_name: " << type_name + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - add_fn)); + {op, GetPersistentStringPiece(device), type_index}, add_fn)); } namespace { @@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, #define REGISTER_VARIANT_ADD_TYPE(T) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ - T, TF_STR(T), \ - AddVariantPrimitiveType<T>); + T, AddVariantPrimitiveType<T>); // No add registered for std::complex<> or Eigen::half objects yet. REGISTER_VARIANT_ADD_TYPE(int); |