summaryrefslogtreecommitdiff
path: root/sleekxmpp/xmlstream/xmlstream.py
diff options
context:
space:
mode:
Diffstat (limited to 'sleekxmpp/xmlstream/xmlstream.py')
-rw-r--r--sleekxmpp/xmlstream/xmlstream.py30
1 files changed, 26 insertions, 4 deletions
diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py
index d12e29b3..30bbaa5e 100644
--- a/sleekxmpp/xmlstream/xmlstream.py
+++ b/sleekxmpp/xmlstream/xmlstream.py
@@ -287,6 +287,7 @@ class XMLStream(object):
self.__filters = {'in': [], 'out': [], 'out_sync': []}
self.__thread_count = 0
self.__thread_cond = threading.Condition()
+ self.__active_threads = set()
self._use_daemons = False
self._disconnect_wait_for_threads = True
@@ -1233,6 +1234,7 @@ class XMLStream(object):
return True
def _start_thread(self, name, target, track=True):
+ self.__active_threads.add(name)
self.__thread[name] = threading.Thread(name=name, target=target)
self.__thread[name].daemon = self._use_daemons
self.__thread[name].start()
@@ -1241,11 +1243,28 @@ class XMLStream(object):
with self.__thread_cond:
self.__thread_count += 1
- def _end_thread(self, name):
+ def _end_thread(self, name, early=False):
with self.__thread_cond:
- self.__thread_count -= 1
- log.debug("Stopped %s thread. %s threads remain." % (
- name, self.__thread_count))
+ curr_thread = threading.current_thread().name
+ if curr_thread in self.__active_threads:
+ self.__thread_count -= 1
+ self.__active_threads.remove(curr_thread)
+
+ if early:
+ log.debug('Threading deadlock prevention!')
+ log.debug(("Marked %s thread as ended due to " + \
+ "disconnect() call. %s threads remain.") % (
+ name, self.__thread_count))
+ else:
+ log.debug("Stopped %s thread. %s threads remain." % (
+ name, self.__thread_count))
+
+ else:
+ log.debug(("Finished exiting %s thread after early " + \
+ "termination from disconnect() call. " + \
+ "%s threads remain.") % (
+ name, self.__thread_count))
+
if self.__thread_count == 0:
self.__thread_cond.notify()
@@ -1254,6 +1273,9 @@ class XMLStream(object):
if self.__thread_count != 0:
log.debug("Waiting for %s threads to exit." %
self.__thread_count)
+ name = threading.current_thread().name
+ if name in self.__thread:
+ self._end_thread(name, early=True)
self.__thread_cond.wait(4)
if self.__thread_count != 0:
log.error("Hanged threads: %s" % threading.enumerate())