aboutsummaryrefslogtreecommitdiffhomepage
path: root/ruby/src/main/java/com
diff options
context:
space:
mode:
Diffstat (limited to 'ruby/src/main/java/com')
-rw-r--r--ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java4
-rw-r--r--ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java92
-rw-r--r--ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java4
-rw-r--r--ruby/src/main/java/com/google/protobuf/jruby/Utils.java37
4 files changed, 91 insertions, 46 deletions
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
index 2d4c03b5..3adaa2a8 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
*/
@JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
- Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
- Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
+ key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
+ value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
IRubyObject symbol;
if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
Utils.isRubyNum(value) &&
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
index 39213c4d..07558fbc 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
@@ -41,6 +41,8 @@ import org.jruby.runtime.ThreadContext;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
@@ -80,8 +82,8 @@ public class RubyMessage extends RubyObject {
hash.visitAll(new RubyHash.Visitor() {
@Override
public void visit(IRubyObject key, IRubyObject value) {
- if (!(key instanceof RubySymbol))
- throw runtime.newTypeError("Expected symbols as hash keys in initialization map.");
+ if (!(key instanceof RubySymbol) && !(key instanceof RubyString))
+ throw runtime.newTypeError("Expected string or symbols as hash keys in initialization map.");
final Descriptors.FieldDescriptor fieldDescriptor = findField(context, key);
if (Utils.isMapEntry(fieldDescriptor)) {
@@ -101,9 +103,15 @@ public class RubyMessage extends RubyObject {
if (oneof != null) {
oneofCases.put(oneof, fieldDescriptor);
}
+
+ if (value instanceof RubyHash && fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
+ RubyDescriptor descriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
+ RubyClass typeClass = (RubyClass) descriptor.msgclass(context);
+ value = (IRubyObject) typeClass.newInstance(context, value, Block.NULL_BLOCK);
+ }
+
fields.put(fieldDescriptor, value);
}
-
}
});
}
@@ -164,8 +172,21 @@ public class RubyMessage extends RubyObject {
*/
@JRubyMethod
public IRubyObject hash(ThreadContext context) {
- int hashCode = System.identityHashCode(this);
- return context.runtime.newFixnum(hashCode);
+ try {
+ MessageDigest digest = MessageDigest.getInstance("SHA-256");
+ for (RubyMap map : maps.values()) {
+ digest.update((byte) map.hashCode());
+ }
+ for (RubyRepeatedField repeatedField : repeatedFields.values()) {
+ digest.update((byte) repeatedFields.hashCode());
+ }
+ for (IRubyObject field : fields.values()) {
+ digest.update((byte) field.hashCode());
+ }
+ return context.runtime.newString(new ByteList(digest.digest()));
+ } catch (NoSuchAlgorithmException ignore) {
+ return context.runtime.newFixnum(System.identityHashCode(this));
+ }
}
/*
@@ -352,7 +373,19 @@ public class RubyMessage extends RubyObject {
for (Descriptors.FieldDescriptor fdef : this.descriptor.getFields()) {
IRubyObject value = getField(context, fdef);
if (!value.isNil()) {
- if (value.respondsTo("to_h")) {
+ if (fdef.isRepeated() && !fdef.isMapField()) {
+ if (fdef.getType() != Descriptors.FieldDescriptor.Type.MESSAGE) {
+ value = Helpers.invoke(context, value, "to_a");
+ } else {
+ RubyArray ary = value.convertToArray();
+ for (int i = 0; i < ary.size(); i++) {
+ IRubyObject submsg = Helpers.invoke(context, ary.eltInternal(i), "to_h");
+ ary.eltInternalSet(i, submsg);
+ }
+
+ value = ary.to_ary();
+ }
+ } else if (value.respondsTo("to_h")) {
value = Helpers.invoke(context, value, "to_h");
} else if (value.respondsTo("to_a")) {
value = Helpers.invoke(context, value, "to_a");
@@ -503,19 +536,12 @@ public class RubyMessage extends RubyObject {
val = value.isTrue();
break;
case BYTES:
+ Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
+ val = ByteString.copyFrom(((RubyString) value).getBytes());
+ break;
case STRING:
- Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value);
- RubyString str = (RubyString) value;
- switch (fieldDescriptor.getType()) {
- case BYTES:
- val = ByteString.copyFrom(str.getBytes());
- break;
- case STRING:
- val = str.asJavaString();
- break;
- default:
- break;
- }
+ Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
+ val = ((RubyString) value).asJavaString();
break;
case MESSAGE:
RubyClass typeClass = (RubyClass) ((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
@@ -528,7 +554,7 @@ public class RubyMessage extends RubyObject {
if (Utils.isRubyNum(value)) {
val = enumDescriptor.findValueByNumberCreatingIfUnknown(RubyNumeric.num2int(value));
- } else if (value instanceof RubySymbol) {
+ } else if (value instanceof RubySymbol || value instanceof RubyString) {
val = enumDescriptor.findValueByName(value.asJavaString());
} else {
throw runtime.newTypeError("Expected number or symbol type for enum field.");
@@ -592,13 +618,17 @@ public class RubyMessage extends RubyObject {
protected IRubyObject getField(ThreadContext context, Descriptors.FieldDescriptor fieldDescriptor) {
Descriptors.OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof();
if (oneofDescriptor != null) {
- if (oneofCases.containsKey(oneofDescriptor)) {
- if (oneofCases.get(oneofDescriptor) != fieldDescriptor)
- return context.runtime.getNil();
+ if (oneofCases.get(oneofDescriptor) == fieldDescriptor) {
return fields.get(fieldDescriptor);
} else {
Descriptors.FieldDescriptor oneofCase = builder.getOneofFieldDescriptor(oneofDescriptor);
- if (oneofCase != fieldDescriptor) return context.runtime.getNil();
+ if (oneofCase != fieldDescriptor) {
+ if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
+ return context.runtime.getNil();
+ } else {
+ return wrapField(context, fieldDescriptor, fieldDescriptor.getDefaultValue());
+ }
+ }
IRubyObject value = wrapField(context, oneofCase, builder.getField(oneofCase));
fields.put(fieldDescriptor, value);
return value;
@@ -691,7 +721,7 @@ public class RubyMessage extends RubyObject {
}
}
if (addValue) {
- Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+ value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
this.fields.put(fieldDescriptor, value);
} else {
this.fields.remove(fieldDescriptor);
@@ -722,8 +752,20 @@ public class RubyMessage extends RubyObject {
Descriptors.FieldDescriptor fieldDescriptor, IRubyObject value) {
RubyArray arr = value.convertToArray();
RubyRepeatedField repeatedField = repeatedFieldForFieldDescriptor(context, fieldDescriptor);
+
+ RubyClass typeClass = null;
+ if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
+ RubyDescriptor descriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
+ typeClass = (RubyClass) descriptor.msgclass(context);
+ }
+
for (int i = 0; i < arr.size(); i++) {
- repeatedField.push(context, arr.eltInternal(i));
+ IRubyObject row = arr.eltInternal(i);
+ if (row instanceof RubyHash && typeClass != null) {
+ row = (IRubyObject) typeClass.newInstance(context, row, Block.NULL_BLOCK);
+ }
+
+ repeatedField.push(context, row);
}
return repeatedField;
}
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java
index 946f9e74..ae2907a9 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyRepeatedField.java
@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
@JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
int arrIndex = normalizeArrayIndex(index);
- Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+ value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
IRubyObject defaultValue = defaultValue(context);
for (int i = this.storage.size(); i < arrIndex; i++) {
this.storage.set(i, defaultValue);
@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
public IRubyObject push(ThreadContext context, IRubyObject value) {
if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
value == context.runtime.getNil())) {
- Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
+ value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
}
this.storage.add(value);
return this.storage;
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/Utils.java b/ruby/src/main/java/com/google/protobuf/jruby/Utils.java
index 596a0979..f199feb9 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/Utils.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/Utils.java
@@ -64,8 +64,8 @@ public class Utils {
return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
}
- public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
- IRubyObject value, RubyModule typeClass) {
+ public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
+ IRubyObject value, RubyModule typeClass) {
Ruby runtime = context.runtime;
Object val;
switch(fieldType) {
@@ -106,7 +106,7 @@ public class Utils {
break;
case BYTES:
case STRING:
- validateStringEncoding(context.runtime, fieldType, value);
+ value = validateStringEncoding(context, fieldType, value);
break;
case MESSAGE:
if (value.getMetaClass() != typeClass) {
@@ -127,6 +127,7 @@ public class Utils {
default:
break;
}
+ return value;
}
public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
@@ -148,10 +149,16 @@ public class Utils {
return runtime.newFloat((Double) value);
case BOOL:
return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
- case BYTES:
- return runtime.newString(((ByteString) value).toStringUtf8());
- case STRING:
- return runtime.newString(value.toString());
+ case BYTES: {
+ IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
+ wrapped.setFrozen(true);
+ return wrapped;
+ }
+ case STRING: {
+ IRubyObject wrapped = runtime.newString(value.toString());
+ wrapped.setFrozen(true);
+ return wrapped;
+ }
default:
return runtime.getNil();
}
@@ -180,25 +187,21 @@ public class Utils {
}
}
- public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
+ public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
if (!(value instanceof RubyString))
- throw runtime.newTypeError("Invalid argument for string field.");
- Encoding encoding = ((RubyString) value).getEncoding();
+ throw context.runtime.newTypeError("Invalid argument for string field.");
switch(type) {
case BYTES:
- if (encoding != ASCIIEncoding.INSTANCE)
- throw runtime.newTypeError("Encoding for bytes fields" +
- " must be \"ASCII-8BIT\", but was " + encoding);
+ value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
break;
case STRING:
- if (encoding != UTF8Encoding.INSTANCE
- && encoding != USASCIIEncoding.INSTANCE)
- throw runtime.newTypeError("Encoding for string fields" +
- " must be \"UTF-8\" or \"ASCII\", but was " + encoding);
+ value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
break;
default:
break;
}
+ value.setFrozen(true);
+ return value;
}
public static void checkNameAvailability(ThreadContext context, String name) {