Simplify cache and extend cache operations.

Remove the external cache dependency and use a simple LRU based on
LinkedHashMap.
Make it possible to get the parse time of DNSMessage, which means we
can evaluate the TTL later on :-)
This commit is contained in:
Rene Treffer 2014-06-12 22:39:51 +02:00
parent 7dd8cfc6e6
commit 4da60e7e20
3 changed files with 92 additions and 16 deletions

View File

@ -27,13 +27,6 @@ if (isSnapshot) {
repositories { repositories {
mavenLocal() mavenLocal()
mavenCentral() mavenCentral()
maven {
url 'https://oss.sonatype.org/content/repositories/snapshots/'
}
}
dependencies {
compile 'org.igniterealtime.jxmpp:jxmpp-util-cache:0.1.0-alpha1-SNAPSHOT'
} }
jar { jar {

View File

@ -13,12 +13,12 @@ import java.security.SecureRandom;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.Random; import java.util.Random;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import org.jxmpp.util.cache.ExpirationCache;
import de.measite.minidns.Record.CLASS; import de.measite.minidns.Record.CLASS;
import de.measite.minidns.Record.TYPE; import de.measite.minidns.Record.TYPE;
@ -30,9 +30,6 @@ public class Client {
private static final Logger LOGGER = Logger.getLogger(Client.class.getName()); private static final Logger LOGGER = Logger.getLogger(Client.class.getName());
protected static final ExpirationCache<Question, DNSMessage> cache = new ExpirationCache<Question, DNSMessage>(
10, 1000 * 60 * 60 * 24);
/** /**
* The internal random class for sequence generation. * The internal random class for sequence generation.
*/ */
@ -48,6 +45,16 @@ public class Client {
*/ */
protected int timeout = 5000; protected int timeout = 5000;
/**
* The internal DNS cache.
*/
protected LinkedHashMap<Question, DNSMessage> cache;
/**
* Maximum acceptable ttl.
*/
protected long maxTTL = 60 * 60 * 1000;
/** /**
* Create a new DNS client. * Create a new DNS client.
*/ */
@ -57,6 +64,7 @@ public class Client {
} catch (NoSuchAlgorithmException e1) { } catch (NoSuchAlgorithmException e1) {
random = new SecureRandom(); random = new SecureRandom();
} }
setCacheSize(10);
} }
/** /**
@ -123,10 +131,21 @@ public class Client {
* @throws IOException On IOErrors. * @throws IOException On IOErrors.
*/ */
public DNSMessage query(Question q, String host, int port) throws IOException { public DNSMessage query(Question q, String host, int port) throws IOException {
DNSMessage dnsMessage = cache.get(q); DNSMessage dnsMessage = (cache == null) ? null : cache.get(q);
if (dnsMessage != null) { if (dnsMessage != null && dnsMessage.getReceiveTimestamp() > 0l) {
// check the ttl
long ttl = maxTTL;
for (Record r : dnsMessage.getAnswers()) {
ttl = Math.min(ttl, r.ttl);
}
for (Record r : dnsMessage.getAdditionalResourceRecords()) {
ttl = Math.min(ttl, r.ttl);
}
if (dnsMessage.getReceiveTimestamp() + ttl <
System.currentTimeMillis()) {
return dnsMessage; return dnsMessage;
} }
}
DNSMessage message = new DNSMessage(); DNSMessage message = new DNSMessage();
message.setQuestions(new Question[]{q}); message.setQuestions(new Question[]{q});
message.setRecursionDesired(true); message.setRecursionDesired(true);
@ -145,7 +164,9 @@ public class Client {
} }
for (Record record : dnsMessage.getAnswers()) { for (Record record : dnsMessage.getAnswers()) {
if (record.isAnswer(q)) { if (record.isAnswer(q)) {
cache.put(q, dnsMessage, record.ttl); if (cache != null) {
cache.put(q, dnsMessage);
}
break; break;
} }
} }
@ -305,4 +326,52 @@ public class Client {
return null; return null;
} }
/**
* Configure the cache size (default 10).
* @param maximumSize The new cache size or 0 to disable.
*/
@SuppressWarnings("serial")
public void setCacheSize(final int maximumSize) {
if (maximumSize == 0) {
this.cache = null;
} else {
LinkedHashMap<Question,DNSMessage> old = cache;
cache = new LinkedHashMap<Question,DNSMessage>() {
@Override
protected boolean removeEldestEntry(
Entry<Question, DNSMessage> eldest) {
return size() > maximumSize;
}
};
if (old != null) {
cache.putAll(old);
}
}
}
/**
* Flush the DNS cache.
*/
public void flushCache() {
if (cache != null) {
cache.clear();
}
}
/**
* Get the current maximum record ttl.
* @return The maximum record ttl.
*/
public long getMaxTTL() {
return maxTTL;
}
/**
* Set the maximum record ttl.
* @param maxTTL The new maximum ttl.
*/
public void setMaxTTL(long maxTTL) {
this.maxTTL = maxTTL;
}
} }

View File

@ -195,6 +195,11 @@ public class DNSMessage {
*/ */
protected Record additionalResourceRecords[]; protected Record additionalResourceRecords[];
/**
* The receive timestamp of this message.
*/
protected long receiveTimestamp;
/** /**
* Retrieve the current DNS message id. * Retrieve the current DNS message id.
* @return The current DNS message id. * @return The current DNS message id.
@ -211,6 +216,14 @@ public class DNSMessage {
this.id = id & 0xffff; this.id = id & 0xffff;
} }
/**
* Get the receive timestamp if this message was created via parse.
* This should be used to evaluate TTLs.
*/
public long getReceiveTimestamp() {
return receiveTimestamp;
}
/** /**
* Retrieve the query type (true or false; * Retrieve the query type (true or false;
* @return True if this DNS message is a query. * @return True if this DNS message is a query.
@ -410,6 +423,7 @@ public class DNSMessage {
message.authenticData = ((header >> 5) & 1) == 1; message.authenticData = ((header >> 5) & 1) == 1;
message.checkDisabled = ((header >> 4) & 1) == 1; message.checkDisabled = ((header >> 4) & 1) == 1;
message.responseCode = RESPONSE_CODE.getResponseCode(header & 0xf); message.responseCode = RESPONSE_CODE.getResponseCode(header & 0xf);
message.receiveTimestamp = System.currentTimeMillis();
int questionCount = dis.readUnsignedShort(); int questionCount = dis.readUnsignedShort();
int answerCount = dis.readUnsignedShort(); int answerCount = dis.readUnsignedShort();
int nameserverCount = dis.readUnsignedShort(); int nameserverCount = dis.readUnsignedShort();