diff --git a/root/root.cpp b/root/root.cpp index 9d00e5ac..6928af89 100644 --- a/root/root.cpp +++ b/root/root.cpp @@ -97,6 +97,7 @@ #include #include #include +#include #include #include "geoip-html.h" @@ -195,6 +196,7 @@ static Meter s_forwardRate; static Meter s_discardedForwardRate; static std::string s_planet; +static std::list< SharedPtr > s_peers; static std::unordered_map< uint64_t,std::unordered_map< MulticastGroup,std::unordered_map< Address,int64_t,AddressHasher >,MulticastGroupHasher > > s_multicastSubscriptions; static std::unordered_map< Identity,SharedPtr,IdentityHasher > s_peersByIdentity; static std::unordered_map< Address,std::set< SharedPtr >,AddressHasher > s_peersByVirtAddr; @@ -205,7 +207,7 @@ static std::map< std::pair< uint32_t,uint32_t >,std::pair< float,float > > s_geo static std::map< std::pair< std::array< uint64_t,2 >,std::array< uint64_t,2 > >,std::pair< float,float > > s_geoIp6; static std::mutex s_planet_l; -static std::mutex s_siblings_l; +static std::mutex s_peers_l; static std::mutex s_multicastSubscriptions_l; static std::mutex s_peersByIdentity_l; static std::mutex s_peersByVirtAddr_l; @@ -246,12 +248,10 @@ static void handlePacket(const int v4s,const int v6s,const InetAddress *const ip // If this is an un-encrypted HELLO, either learn a new peer or verify // that this is a peer we already know. if ((pkt.cipher() == ZT_PROTO_CIPHER_SUITE__POLY1305_NONE)&&(pkt.verb() == Packet::VERB_HELLO)) { - std::lock_guard pbi_l(s_peersByIdentity_l); - std::lock_guard pbv_l(s_peersByVirtAddr_l); - Identity id; if (id.deserialize(pkt,ZT_PROTO_VERB_HELLO_IDX_IDENTITY)) { { + std::lock_guard pbi_l(s_peersByIdentity_l); auto pById = s_peersByIdentity.find(id); if (pById != s_peersByIdentity.end()) { peer = pById->second; @@ -273,8 +273,14 @@ static void handlePacket(const int v4s,const int v6s,const InetAddress *const ip } peer->id = id; peer->lastReceive = now; - s_peersByIdentity.emplace(id,peer); - s_peersByVirtAddr[id.address()].emplace(peer); + std::lock_guard pl(s_peers_l); + std::lock_guard pbi_l(s_peersByIdentity_l); + std::lock_guard pbv_l(s_peersByVirtAddr_l); + if (s_peersByIdentity.find(id) == s_peersByIdentity.end()) { // double check to ensure another thread didn't add this + s_peers.emplace_back(peer); + s_peersByIdentity.emplace(id,peer); + s_peersByVirtAddr[id.address()].emplace(peer); + } } else { printf("%s HELLO rejected: packet authentication failed" ZT_EOL_S,ip->toString(ipstr)); return; @@ -1117,37 +1123,41 @@ int main(int argc,char **argv) // Remove expired peers { - std::lock_guard pbi_l(s_peersByIdentity_l); - for(auto p=s_peersByIdentity.begin();p!=s_peersByIdentity.end();) { - if ((now - p->second->lastReceive) > ZT_PEER_ACTIVITY_TIMEOUT) { + std::lock_guard pbi_l(s_peers_l); + for(auto p=s_peers.begin();p!=s_peers.end();) { + if ((now - (*p)->lastReceive) > ZT_PEER_ACTIVITY_TIMEOUT) { + std::lock_guard pbi_l(s_peersByIdentity_l); std::lock_guard pbv_l(s_peersByVirtAddr_l); std::lock_guard pbp_l(s_peersByPhysAddr_l); - auto pbv = s_peersByVirtAddr.find(p->second->id.address()); + s_peersByIdentity.erase((*p)->id); + + auto pbv = s_peersByVirtAddr.find((*p)->id.address()); if (pbv != s_peersByVirtAddr.end()) { - pbv->second.erase(p->second); + pbv->second.erase((*p)); if (pbv->second.empty()) s_peersByVirtAddr.erase(pbv); } - if (p->second->ip4) { - auto pbp = s_peersByPhysAddr.find(p->second->ip4); + if ((*p)->ip4) { + auto pbp = s_peersByPhysAddr.find((*p)->ip4); if (pbp != s_peersByPhysAddr.end()) { - pbp->second.erase(p->second); - if (pbp->second.empty()) - s_peersByPhysAddr.erase(pbp); - } - } - if (p->second->ip6) { - auto pbp = s_peersByPhysAddr.find(p->second->ip6); - if (pbp != s_peersByPhysAddr.end()) { - pbp->second.erase(p->second); + pbp->second.erase((*p)); if (pbp->second.empty()) s_peersByPhysAddr.erase(pbp); } } - s_peersByIdentity.erase(p++); + if ((*p)->ip6) { + auto pbp = s_peersByPhysAddr.find((*p)->ip6); + if (pbp != s_peersByPhysAddr.end()) { + pbp->second.erase((*p)); + if (pbp->second.empty()) + s_peersByPhysAddr.erase(pbp); + } + } + + s_peers.erase(p++); } else ++p; } }