Skip to content

Commit

Permalink
Add equals and hashCode to generate classes
Browse files Browse the repository at this point in the history
Fixes #377

(cherry picked from commit 22ca4c8)
  • Loading branch information
acogoluegnes committed Aug 16, 2018
1 parent 6fecb1c commit de907f0
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 1 deletion.
48 changes: 48 additions & 0 deletions codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ def printGetter(fieldType, fieldName):
print(" public int getClassId() { return %i; }" % (c.index))
print(" public String getClassName() { return \"%s\"; }" % (c.name))

if c.fields:
equalsHashCode(spec, c.fields, java_class_name(c.name), 'Properties', False)

printPropertiesBuilder(c)

#accessor methods
Expand Down Expand Up @@ -400,6 +403,49 @@ def printPropertiesClasses():

#--------------------------------------------------------------------------------

def equalsHashCode(spec, fields, jClassName, classSuffix, usePrimitiveType):
print()
print()
print(" @Override")
print(" public boolean equals(Object o) {")
print(" if (this == o)")
print(" return true;")
print(" if (o == null || getClass() != o.getClass())")
print(" return false;")
print(" %s%s that = (%s%s) o;" % (jClassName, classSuffix, jClassName, classSuffix))

for f in fields:
(fType, fName) = (java_field_type(spec, f.domain), java_field_name(f.name))
if usePrimitiveType and fType in javaScalarTypes:
print(" if (%s != that.%s)" % (fName, fName))
else:
print(" if (%s != null ? !%s.equals(that.%s) : that.%s != null)" % (fName, fName, fName, fName))

print(" return false;")

print(" return true;")
print(" }")

print()
print(" @Override")
print(" public int hashCode() {")
print(" int result = 0;")

for f in fields:
(fType, fName) = (java_field_type(spec, f.domain), java_field_name(f.name))
if usePrimitiveType and fType in javaScalarTypes:
if fType == 'boolean':
print(" result = 31 * result + (%s ? 1 : 0);" % fName)
elif fType == 'long':
print(" result = 31 * result + (int) (%s ^ (%s >>> 32));" % (fName, fName))
else:
print(" result = 31 * result + %s;" % fName)
else:
print(" result = 31 * result + (%s != null ? %s.hashCode() : 0);" % (fName, fName))

print(" return result;")
print(" }")

def genJavaImpl(spec):
def printHeader():
printFileHeader()
Expand Down Expand Up @@ -503,6 +549,8 @@ def write_arguments():
getters()
constructors()
others()
if m.arguments:
equalsHashCode(spec, m.arguments, java_class_name(m.name), '', True)

argument_debug_string()
write_arguments()
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/rabbitmq/client/test/ClientTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
JacksonJsonRpcTest.class,
AddressTest.class,
DefaultRetryHandlerTest.class,
NioDeadlockOnConnectionClosing.class
NioDeadlockOnConnectionClosing.class,
GeneratedClassesTest.class
})
public class ClientTests {

Expand Down
135 changes: 135 additions & 0 deletions src/test/java/com/rabbitmq/client/test/GeneratedClassesTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (c) 2018 Pivotal Software, Inc. All rights reserved.
//
// This software, the RabbitMQ Java client library, is triple-licensed under the
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
// please see LICENSE-APACHE2.
//
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
// either express or implied. See the LICENSE file for specific language governing
// rights and limitations of this software.
//
// If you have any questions regarding licensing, please contact us at
// info@rabbitmq.com.

package com.rabbitmq.client.test;

import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.impl.AMQImpl;
import org.junit.Test;

import java.util.Calendar;
import java.util.Date;

import static java.util.Collections.singletonMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;

/**
*
*/
public class GeneratedClassesTest {

@Test
public void amqpPropertiesEqualsHashCode() {
checkEquals(
new AMQP.BasicProperties.Builder().correlationId("one").build(),
new AMQP.BasicProperties.Builder().correlationId("one").build()
);
checkNotEquals(
new AMQP.BasicProperties.Builder().correlationId("one").build(),
new AMQP.BasicProperties.Builder().correlationId("two").build()
);
Date date = Calendar.getInstance().getTime();
checkEquals(
new AMQP.BasicProperties.Builder()
.deliveryMode(1)
.headers(singletonMap("one", "two"))
.correlationId("123")
.expiration("later")
.priority(10)
.replyTo("me")
.contentType("text/plain")
.contentEncoding("UTF-8")
.userId("jdoe")
.appId("app1")
.clusterId("cluster1")
.messageId("message123")
.timestamp(date)
.type("type")
.build(),
new AMQP.BasicProperties.Builder()
.deliveryMode(1)
.headers(singletonMap("one", "two"))
.correlationId("123")
.expiration("later")
.priority(10)
.replyTo("me")
.contentType("text/plain")
.contentEncoding("UTF-8")
.userId("jdoe")
.appId("app1")
.clusterId("cluster1")
.messageId("message123")
.timestamp(date)
.type("type")
.build()
);
checkNotEquals(
new AMQP.BasicProperties.Builder()
.deliveryMode(1)
.headers(singletonMap("one", "two"))
.correlationId("123")
.expiration("later")
.priority(10)
.replyTo("me")
.contentType("text/plain")
.contentEncoding("UTF-8")
.userId("jdoe")
.appId("app1")
.clusterId("cluster1")
.messageId("message123")
.timestamp(date)
.type("type")
.build(),
new AMQP.BasicProperties.Builder()
.deliveryMode(2)
.headers(singletonMap("one", "two"))
.correlationId("123")
.expiration("later")
.priority(10)
.replyTo("me")
.contentType("text/plain")
.contentEncoding("UTF-8")
.userId("jdoe")
.appId("app1")
.clusterId("cluster1")
.messageId("message123")
.timestamp(date)
.type("type")
.build()
);

}

@Test public void amqImplEqualsHashCode() {
checkEquals(
new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk"),
new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk")
);
checkNotEquals(
new AMQImpl.Basic.Deliver("tag", 1L, false, "amq.direct","rk"),
new AMQImpl.Basic.Deliver("tag", 2L, false, "amq.direct","rk")
);
}

private void checkEquals(Object o1, Object o2) {
assertEquals(o1, o2);
assertEquals(o1.hashCode(), o2.hashCode());
}

private void checkNotEquals(Object o1, Object o2) {
assertNotEquals(o1, o2);
}
}

0 comments on commit de907f0

Please sign in to comment.