diff options
Diffstat (limited to 'ruby/src/main/java/com')
4 files changed, 91 insertions, 46 deletions
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java index 2d4c03b5..3adaa2a8 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java @@ -148,8 +148,8 @@ public class RubyMap extends RubyObject { */ @JRubyMethod(name = "[]=") public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) { - Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass); - Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass); + key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass); + value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass); IRubyObject symbol; if (valueType == Descriptors.FieldDescriptor.Type.ENUM && Utils.isRubyNum(value) && diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java index 39213c4d..07558fbc 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java @@ -41,6 +41,8 @@ import org.jruby.runtime.ThreadContext; import org.jruby.runtime.builtin.IRubyObject; import org.jruby.util.ByteList; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.HashMap; import java.util.Map; @@ -80,8 +82,8 @@ public class RubyMessage extends RubyObject { hash.visitAll(new RubyHash.Visitor() { @Override public void visit(IRubyObject key, IRubyObject value) { - if (!(key instanceof RubySymbol)) - throw runtime.newTypeError("Expected symbols as hash keys in initialization map."); + if (!(key instanceof RubySymbol) && !(key instanceof RubyString)) + throw runtime.newTypeError("Expected string or symbols as hash keys in initialization map."); final Descriptors.FieldDescriptor fieldDescriptor = findField(context, key); if (Utils.isMapEntry(fieldDescriptor)) { @@ -101,9 +103,15 @@ public class RubyMessage extends RubyObject { if (oneof != null) { oneofCases.put(oneof, fieldDescriptor); } + + if (value instanceof RubyHash && fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) { + RubyDescriptor descriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor); + RubyClass typeClass = (RubyClass) descriptor.msgclass(context); + value = (IRubyObject) typeClass.newInstance(context, value, Block.NULL_BLOCK); + } + fields.put(fieldDescriptor, value); } - } }); } @@ -164,8 +172,21 @@ public class RubyMessage extends RubyObject { */ @JRubyMethod public IRubyObject hash(ThreadContext context) { - int hashCode = System.identityHashCode(this); - return context.runtime.newFixnum(hashCode); + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + for (RubyMap map : maps.values()) { + digest.update((byte) map.hashCode()); + } + for (RubyRepeatedField repeatedField : repeatedFields.values()) { + digest.update((byte) repeatedFields.hashCode()); + } + for (IRubyObject field : fields.values()) { + digest.update((byte) field.hashCode()); + } + return context.runtime.newString(new ByteList(digest.digest())); + } catch (NoSuchAlgorithmException ignore) { + return context.runtime.newFixnum(System.identityHashCode(this)); + } } /* @@ -352,7 +373,19 @@ public class RubyMessage extends RubyObject { for (Descriptors.FieldDescriptor fdef : this.descriptor.getFields()) { IRubyObject value = getField(context, fdef); if (!value.isNil()) { - if (value.respondsTo("to_h")) { + if (fdef.isRepeated() && !fdef.isMapField()) { + if (fdef.getType() != Descriptors.FieldDescriptor.Type.MESSAGE) { + value = Helpers.invoke(context, value, "to_a"); + } else { + RubyArray ary = value.convertToArray(); + for (int i = 0; i < ary.size(); i++) { + IRubyObject submsg = Helpers.invoke(context, ary.eltInternal(i), "to_h"); + ary.eltInternalSet(i, submsg); + } + + value = ary.to_ary(); + } + } else if (value.respondsTo("to_h")) { value = Helpers.invoke(context, value, "to_h"); } else if (value.respondsTo("to_a")) { value = Helpers.invoke(context, value, "to_a"); @@ -503,19 +536,12 @@ public class RubyMessage extends RubyObject { val = value.isTrue(); break; case BYTES: + Utils.validateStringEncoding(context, fieldDescriptor.getType(), value); + val = ByteString.copyFrom(((RubyString) value).getBytes()); + break; case STRING: - Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value); - RubyString str = (RubyString) value; - switch (fieldDescriptor.getType()) { - case BYTES: - val = ByteString.copyFrom(str.getBytes()); - break; - case STRING: - val = str.asJavaString(); - break; - default: - break; - } + Utils.validateStringEncoding(context, fieldDescriptor.getType(), value); + val = ((RubyString) value).asJavaString(); break; case MESSAGE: RubyClass typeClass = (RubyClass) ((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context); @@ -528,7 +554,7 @@ public class RubyMessage extends RubyObject { if (Utils.isRubyNum(value)) { val = enumDescriptor.findValueByNumberCreatingIfUnknown(RubyNumeric.num2int(value)); - } else if (value instanceof RubySymbol) { + } else if (value instanceof RubySymbol || value instanceof RubyString) { val = enumDescriptor.findValueByName(value.asJavaString()); } else { throw runtime.newTypeError("Expected number or symbol type for enum field."); @@ -592,13 +618,17 @@ public class RubyMessage extends RubyObject { protected IRubyObject getField(ThreadContext context, Descriptors.FieldDescriptor fieldDescriptor) { Descriptors.OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof(); if (oneofDescriptor != null) { - if (oneofCases.containsKey(oneofDescriptor)) { - if (oneofCases.get(oneofDescriptor) != fieldDescriptor) - return context.runtime.getNil(); + if (oneofCases.get(oneofDescriptor) == fieldDescriptor) { return fields.get(fieldDescriptor); } else { Descriptors.FieldDescriptor oneofCase = builder.getOneofFieldDescriptor(oneofDescriptor); - if (oneofCase != fieldDescriptor) return context.runtime.getNil(); + if (oneofCase != fieldDescriptor) { + if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) { + return context.runtime.getNil(); + } else { + return wrapField(context, fieldDescriptor, fieldDescriptor.getDefaultValue()); + } + } IRubyObject value = wrapField(context, oneofCase, builder.getField(oneofCase)); fields.put(fieldDescriptor, value); return value; @@ -691,7 +721,7 @@ public class RubyMessage extends RubyObject { } } if (addValue) { - Utils.checkType(context, fieldType, value, (RubyModule) typeClass); + value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass); this.fields.put(fieldDescriptor, value); } else { this.fields.remove(fieldDescriptor); @@ -722,8 +752,20 @@ public class RubyMessage extends RubyObject { Descriptors.FieldDescriptor fieldDescriptor, IRubyObject value) { RubyArray arr = value.convertToArray(); RubyRepeatedField repeatedField = repeatedFieldForFieldDescriptor(context, fieldDescriptor); + + RubyClass typeClass = null; + if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) { + RubyDescriptor descriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor); + typeClass = (RubyClass) descriptor.msgclass(context); + } + for (int i = 0; i < arr.size(); i++) { - repeatedField.push(context, arr.eltInternal(i)); + IRubyObject row = arr.eltInternal(i); + if (row instanceof RubyHash && typeClass != null) { + row = (IRubyObject) typeClass.newInstance(context, row, Block.NULL_BLOCK); + } + + repeatedField.push(context, row); } return repeatedField; } diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java index 946f9e74..ae2907a9 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java @@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject { @JRubyMethod(name = "[]=") public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) { int arrIndex = normalizeArrayIndex(index); - Utils.checkType(context, fieldType, value, (RubyModule) typeClass); + value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass); IRubyObject defaultValue = defaultValue(context); for (int i = this.storage.size(); i < arrIndex; i++) { this.storage.set(i, defaultValue); @@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject { public IRubyObject push(ThreadContext context, IRubyObject value) { if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE && value == context.runtime.getNil())) { - Utils.checkType(context, fieldType, value, (RubyModule) typeClass); + value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass); } this.storage.add(value); return this.storage; diff --git a/ruby/src/main/java/com/google/protobuf/jruby/Utils.java b/ruby/src/main/java/com/google/protobuf/jruby/Utils.java index 596a0979..f199feb9 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/Utils.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/Utils.java @@ -64,8 +64,8 @@ public class Utils { return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase()); } - public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, - IRubyObject value, RubyModule typeClass) { + public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, + IRubyObject value, RubyModule typeClass) { Ruby runtime = context.runtime; Object val; switch(fieldType) { @@ -106,7 +106,7 @@ public class Utils { break; case BYTES: case STRING: - validateStringEncoding(context.runtime, fieldType, value); + value = validateStringEncoding(context, fieldType, value); break; case MESSAGE: if (value.getMetaClass() != typeClass) { @@ -127,6 +127,7 @@ public class Utils { default: break; } + return value; } public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) { @@ -148,10 +149,16 @@ public class Utils { return runtime.newFloat((Double) value); case BOOL: return (Boolean) value ? runtime.getTrue() : runtime.getFalse(); - case BYTES: - return runtime.newString(((ByteString) value).toStringUtf8()); - case STRING: - return runtime.newString(value.toString()); + case BYTES: { + IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8()); + wrapped.setFrozen(true); + return wrapped; + } + case STRING: { + IRubyObject wrapped = runtime.newString(value.toString()); + wrapped.setFrozen(true); + return wrapped; + } default: return runtime.getNil(); } @@ -180,25 +187,21 @@ public class Utils { } } - public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) { + public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) { if (!(value instanceof RubyString)) - throw runtime.newTypeError("Invalid argument for string field."); - Encoding encoding = ((RubyString) value).getEncoding(); + throw context.runtime.newTypeError("Invalid argument for string field."); switch(type) { case BYTES: - if (encoding != ASCIIEncoding.INSTANCE) - throw runtime.newTypeError("Encoding for bytes fields" + - " must be \"ASCII-8BIT\", but was " + encoding); + value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT")); break; case STRING: - if (encoding != UTF8Encoding.INSTANCE - && encoding != USASCIIEncoding.INSTANCE) - throw runtime.newTypeError("Encoding for string fields" + - " must be \"UTF-8\" or \"ASCII\", but was " + encoding); + value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8")); break; default: break; } + value.setFrozen(true); + return value; } public static void checkNameAvailability(ThreadContext context, String name) { |