prevent race condition when fetching device ids

This commit is contained in:
Daniel Gultsch 2018-10-03 22:03:40 +02:00
parent f608fb349a
commit 23282484d6
2 changed files with 47 additions and 43 deletions

View File

@ -1021,28 +1021,33 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
} }
if (packet != null) { if (packet != null) {
mXmppConnectionService.sendIqPacket(account, packet, (account, response) -> { mXmppConnectionService.sendIqPacket(account, packet, (account, response) -> {
synchronized (fetchDeviceIdsMap) { if (response.getType() == IqPacket.TYPE.RESULT) {
List<OnDeviceIdsFetched> callbacks = fetchDeviceIdsMap.remove(jid); fetchDeviceListStatus.put(jid, true);
if (response.getType() == IqPacket.TYPE.RESULT) { Element item = mXmppConnectionService.getIqParser().getItem(response);
fetchDeviceListStatus.put(jid, true); Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item);
Element item = mXmppConnectionService.getIqParser().getItem(response); registerDevices(jid, deviceIds);
Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item); final List<OnDeviceIdsFetched> callbacks;
registerDevices(jid, deviceIds); synchronized (fetchDeviceIdsMap) {
if (callbacks != null) { callbacks = fetchDeviceIdsMap.remove(jid);
for (OnDeviceIdsFetched c : callbacks) { }
c.fetched(jid, deviceIds); if (callbacks != null) {
} for (OnDeviceIdsFetched c : callbacks) {
c.fetched(jid, deviceIds);
} }
}
} else {
if (response.getType() == IqPacket.TYPE.TIMEOUT) {
fetchDeviceListStatus.remove(jid);
} else { } else {
if (response.getType() == IqPacket.TYPE.TIMEOUT) { fetchDeviceListStatus.put(jid, false);
fetchDeviceListStatus.remove(jid); }
} else { final List<OnDeviceIdsFetched> callbacks;
fetchDeviceListStatus.put(jid, false); synchronized (fetchDeviceIdsMap) {
} callbacks = fetchDeviceIdsMap.remove(jid);
if (callbacks != null) { }
for (OnDeviceIdsFetched c : callbacks) { if (callbacks != null) {
c.fetched(jid, null); for (OnDeviceIdsFetched c : callbacks) {
} c.fetched(jid, null);
} }
} }
} }
@ -1157,8 +1162,9 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
Set<SignalProtocolAddress> addresses = new HashSet<>(); Set<SignalProtocolAddress> addresses = new HashSet<>();
for (Jid jid : getCryptoTargets(conversation)) { for (Jid jid : getCryptoTargets(conversation)) {
Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Finding devices without session for " + jid); Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Finding devices without session for " + jid);
if (deviceIds.get(jid) != null) { Set<Integer> ids = deviceIds.get(jid);
for (Integer foreignId : this.deviceIds.get(jid)) { if (deviceIds.get(jid) != null && !ids.isEmpty()) {
for (Integer foreignId : ids) {
SignalProtocolAddress address = new SignalProtocolAddress(jid.toString(), foreignId); SignalProtocolAddress address = new SignalProtocolAddress(jid.toString(), foreignId);
if (sessions.get(address) == null) { if (sessions.get(address) == null) {
IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey(); IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey();
@ -1181,22 +1187,21 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
Log.w(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Have no target devices in PEP!"); Log.w(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Have no target devices in PEP!");
} }
} }
if (deviceIds.get(account.getJid().asBareJid()) != null) { Set<Integer> ownIds = this.deviceIds.get(account.getJid().asBareJid());
for (Integer ownId : this.deviceIds.get(account.getJid().asBareJid())) { for (Integer ownId : (ownIds != null ? ownIds : new HashSet<Integer>())) {
SignalProtocolAddress address = new SignalProtocolAddress(account.getJid().asBareJid().toString(), ownId); SignalProtocolAddress address = new SignalProtocolAddress(account.getJid().asBareJid().toString(), ownId);
if (sessions.get(address) == null) { if (sessions.get(address) == null) {
IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey(); IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey();
if (identityKey != null) { if (identityKey != null) {
Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Already have session for " + address.toString() + ", adding to cache..."); Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Already have session for " + address.toString() + ", adding to cache...");
XmppAxolotlSession session = new XmppAxolotlSession(account, axolotlStore, address, identityKey); XmppAxolotlSession session = new XmppAxolotlSession(account, axolotlStore, address, identityKey);
sessions.put(address, session); sessions.put(address, session);
} else {
Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Found device " + account.getJid().asBareJid() + ":" + ownId);
if (fetchStatusMap.get(address) != FetchStatus.ERROR) {
addresses.add(address);
} else { } else {
Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Found device " + account.getJid().asBareJid() + ":" + ownId); Log.d(Config.LOGTAG, getLogprefix(account) + "skipping over " + address + " because it's broken");
if (fetchStatusMap.get(address) != FetchStatus.ERROR) {
addresses.add(address);
} else {
Log.d(Config.LOGTAG, getLogprefix(account) + "skipping over " + address + " because it's broken");
}
} }
} }
} }
@ -1215,12 +1220,7 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
} }
Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": createSessionsIfNeeded() - jids with empty device list: " + jidsWithEmptyDeviceList); Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": createSessionsIfNeeded() - jids with empty device list: " + jidsWithEmptyDeviceList);
if (jidsWithEmptyDeviceList.size() > 0) { if (jidsWithEmptyDeviceList.size() > 0) {
fetchDeviceIds(jidsWithEmptyDeviceList, new OnMultipleDeviceIdFetched() { fetchDeviceIds(jidsWithEmptyDeviceList, () -> createSessionsIfNeededActual(conversation));
@Override
public void fetched() {
createSessionsIfNeededActual(conversation);
}
});
return true; return true;
} else { } else {
return createSessionsIfNeededActual(conversation); return createSessionsIfNeededActual(conversation);

View File

@ -78,6 +78,10 @@ public class FingerprintStatus implements Comparable<FingerprintStatus> {
return status; return status;
} }
public static FingerprintStatus createActive(Boolean trusted) {
return createActive(trusted != null && trusted);
}
public static FingerprintStatus createActive(boolean trusted) { public static FingerprintStatus createActive(boolean trusted) {
final FingerprintStatus status = new FingerprintStatus(); final FingerprintStatus status = new FingerprintStatus();
status.trust = trusted ? Trust.TRUSTED : Trust.UNTRUSTED; status.trust = trusted ? Trust.TRUSTED : Trust.UNTRUSTED;