diff options
Diffstat (limited to 'sleekxmpp')
-rw-r--r-- | sleekxmpp/xmlstream/xmlstream.py | 30 |
1 files changed, 26 insertions, 4 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index d12e29b3..30bbaa5e 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -287,6 +287,7 @@ class XMLStream(object): self.__filters = {'in': [], 'out': [], 'out_sync': []} self.__thread_count = 0 self.__thread_cond = threading.Condition() + self.__active_threads = set() self._use_daemons = False self._disconnect_wait_for_threads = True @@ -1233,6 +1234,7 @@ class XMLStream(object): return True def _start_thread(self, name, target, track=True): + self.__active_threads.add(name) self.__thread[name] = threading.Thread(name=name, target=target) self.__thread[name].daemon = self._use_daemons self.__thread[name].start() @@ -1241,11 +1243,28 @@ class XMLStream(object): with self.__thread_cond: self.__thread_count += 1 - def _end_thread(self, name): + def _end_thread(self, name, early=False): with self.__thread_cond: - self.__thread_count -= 1 - log.debug("Stopped %s thread. %s threads remain." % ( - name, self.__thread_count)) + curr_thread = threading.current_thread().name + if curr_thread in self.__active_threads: + self.__thread_count -= 1 + self.__active_threads.remove(curr_thread) + + if early: + log.debug('Threading deadlock prevention!') + log.debug(("Marked %s thread as ended due to " + \ + "disconnect() call. %s threads remain.") % ( + name, self.__thread_count)) + else: + log.debug("Stopped %s thread. %s threads remain." % ( + name, self.__thread_count)) + + else: + log.debug(("Finished exiting %s thread after early " + \ + "termination from disconnect() call. " + \ + "%s threads remain.") % ( + name, self.__thread_count)) + if self.__thread_count == 0: self.__thread_cond.notify() @@ -1254,6 +1273,9 @@ class XMLStream(object): if self.__thread_count != 0: log.debug("Waiting for %s threads to exit." % self.__thread_count) + name = threading.current_thread().name + if name in self.__thread: + self._end_thread(name, early=True) self.__thread_cond.wait(4) if self.__thread_count != 0: log.error("Hanged threads: %s" % threading.enumerate()) |