diff options
-rw-r--r-- | sleekxmpp/basexmpp.py | 3 | ||||
-rw-r--r-- | sleekxmpp/clientxmpp.py | 355 | ||||
-rw-r--r-- | sleekxmpp/stanza/__init__.py | 4 | ||||
-rw-r--r-- | sleekxmpp/stanza/bind.py | 26 | ||||
-rw-r--r-- | sleekxmpp/stanza/sasl.py | 104 | ||||
-rw-r--r-- | sleekxmpp/stanza/session.py | 25 | ||||
-rw-r--r-- | sleekxmpp/stanza/stream_features.py | 52 | ||||
-rw-r--r-- | sleekxmpp/stanza/tls.py | 50 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/stanzabase.py | 7 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/tostring/tostring.py | 22 | ||||
-rw-r--r-- | sleekxmpp/xmlstream/tostring/tostring26.py | 22 | ||||
-rw-r--r-- | tests/test_tostring.py | 4 |
12 files changed, 512 insertions, 162 deletions
diff --git a/sleekxmpp/basexmpp.py b/sleekxmpp/basexmpp.py index 8347bfe0..e2865d39 100644 --- a/sleekxmpp/basexmpp.py +++ b/sleekxmpp/basexmpp.py @@ -92,6 +92,7 @@ class BaseXMPP(XMLStream): # Deprecated method names are re-mapped for backwards compatibility. self.default_ns = default_ns self.stream_ns = 'http://etherx.jabber.org/streams' + self.namespace_map[self.stream_ns] = 'stream' self.boundjid = JID("") @@ -105,6 +106,8 @@ class BaseXMPP(XMLStream): self.sentpresence = False + self.stanza = sleekxmpp.stanza + self.register_handler( Callback('IM', MatchXPath('{%s}message/{%s}body' % (self.default_ns, diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index c518a4ce..dc08522d 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -18,9 +18,11 @@ import threading from sleekxmpp import plugins from sleekxmpp import stanza from sleekxmpp.basexmpp import BaseXMPP -from sleekxmpp.stanza import Message, Presence, Iq +from sleekxmpp.stanza import * +from sleekxmpp.stanza import tls +from sleekxmpp.stanza import sasl from sleekxmpp.xmlstream import XMLStream, RestartStream -from sleekxmpp.xmlstream import StanzaBase, ET +from sleekxmpp.xmlstream import StanzaBase, ET, register_stanza_plugin from sleekxmpp.xmlstream.matcher import * from sleekxmpp.xmlstream.handler import * @@ -85,14 +87,24 @@ class ClientXMPP(BaseXMPP): self.stream_footer = "</stream:stream>" self.features = [] - self.registered_features = [] + self._stream_feature_handlers = {} + self._stream_feature_order = [] + self._sasl_mechanism_handlers = {} + self._sasl_mechanism_priorities = [] #TODO: Use stream state here self.authenticated = False self.sessionstarted = False self.bound = False self.bindfail = False - self.add_event_handler('connected', self.handle_connected) + + self.add_event_handler('connected', self._handle_connected) + + self.register_stanza(StreamFeatures) + self.register_stanza(tls.Proceed) + self.register_stanza(sasl.Success) + self.register_stanza(sasl.Failure) + self.register_stanza(sasl.Auth) self.register_handler( Callback('Stream Features', @@ -105,32 +117,25 @@ class ClientXMPP(BaseXMPP): 'jabber:iq:roster')), self._handle_roster)) - self.register_feature( - "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls' />", - self._handle_starttls, True) - self.register_feature( - "<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl' />", - self._handle_sasl_auth, True) - self.register_feature( - "<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind' />", - self._handle_bind_resource) - self.register_feature( - "<session xmlns='urn:ietf:params:xml:ns:xmpp-session' />", - self._handle_start_session) - - def handle_connected(self, event=None): - #TODO: Use stream state here - self.authenticated = False - self.sessionstarted = False - self.bound = False - self.bindfail = False - self.schedule("session timeout checker", 15, - self._session_timeout_check) - - def _session_timeout_check(self): - if not self.session_started_event.isSet(): - log.debug("Session start has taken more than 15 seconds") - self.disconnect(reconnect=self.auto_reconnect) + self.register_feature('starttls', self._handle_starttls, + restart=True, + order=0) + self.register_feature('mechanisms', self._handle_sasl_auth, + restart=True, + order=100) + self.register_feature('bind', self._handle_bind_resource, + restart=False, + order=10000) + self.register_feature('session', self._handle_start_session, + restart=False, + order=10001) + + self.register_sasl_mechanism('PLAIN', + self._handle_sasl_plain, + priority=1) + self.register_sasl_mechanism('ANONYMOUS', + self._handle_sasl_plain, + priority=0) def connect(self, address=tuple(), reattempt=True, use_tls=True): """ @@ -192,19 +197,54 @@ class ClientXMPP(BaseXMPP): return XMLStream.connect(self, address[0], address[1], use_tls=use_tls, reattempt=reattempt) - def register_feature(self, mask, pointer, breaker=False): + def register_feature(self, name, handler, restart=False, order=5000): """ Register a stream feature. Arguments: - mask -- An XML string matching the feature's element. - pointer -- The function to execute if the feature is received. - breaker -- Indicates if feature processing should halt with + name -- The name of the stream feature. + handler -- The function to execute if the feature is received. + restart -- Indicates if feature processing should halt with this feature. Defaults to False. + order -- The relative ordering in which the feature should + be negotiated. Lower values will be attempted + earlier when available. + """ + self._stream_feature_handlers[name] = (handler, restart) + self._stream_feature_order.append((order, name)) + self._stream_feature_order.sort() + + def register_sasl_mechanism(self, name, handler, priority=0): + """ + Register a handler for a SASL authentication mechanism. + + Arguments: + name -- The name of the mechanism (all caps) + handler -- The function that will perform the + authentication. The function must + return True if it is able to carry + out the authentication, False if + a required condition is not met. + priority -- An integer value indicating the + preferred ordering for the mechanism. + High values will be attempted first. + """ + self._sasl_mechanism_handlers[name] = handler + self._sasl_mechanism_priorities.append((priority, name)) + self._sasl_mechanism_priorities.sort(reverse=True) + + def remove_sasl_mechanism(self, name): + """ + Remove support for a given SASL authentication mechanism. + + Arguments: + name -- The name of the mechanism to remove (all caps) """ - self.registered_features.append((MatchXMLMask(mask), - pointer, - breaker)) + if name in self._sasl_mechanism_handlers: + del self._sasl_mechanism_handlers[name] + + p = self._sasl_mechanism_priorities + self._sasl_mechanism_priorities = [i for i in p if i[1] != name] def update_roster(self, jid, name=None, subscription=None, groups=[], block=True, timeout=None, callback=None): @@ -276,6 +316,21 @@ class ClientXMPP(BaseXMPP): else: return self._handle_roster(response, request=True) + def _handle_connected(self, event=None): + #TODO: Use stream state here + self.authenticated = False + self.sessionstarted = False + self.bound = False + self.bindfail = False + self.features = [] + + def session_timeout(): + if not self.session_started_event.isSet(): + log.debug("Session start has taken more than 15 seconds") + self.disconnect(reconnect=self.auto_reconnect) + + self.schedule("session timeout checker", 15, session_timeout) + def _handle_stream_features(self, features): """ Process the received stream features. @@ -283,170 +338,176 @@ class ClientXMPP(BaseXMPP): Arguments: features -- The features stanza. """ - # Record all of the features. - self.features = [] - for sub in features.xml: - self.features.append(sub.tag) - - # Process the features. - for sub in features.xml: - for feature in self.registered_features: - mask, handler, halt = feature - if mask.match(sub): - if handler(sub) and halt: - # Don't continue if the feature was - # marked as a breaker. - return True - - def _handle_starttls(self, xml): + for order, name in self._stream_feature_order: + if name in features['features']: + handler, restart = self._stream_feature_handlers[name] + if handler(features) and restart: + # Don't continue if the feature requires + # restarting the XML stream. + return True + + def _handle_starttls(self, features): """ Handle notification that the server supports TLS. Arguments: - xml -- The STARTLS proceed element. + features -- The stream:features element. """ + + def tls_proceed(proceed): + """Restart the XML stream when TLS is accepted.""" + log.debug("Starting TLS") + if self.start_tls(): + self.features.append('starttls') + raise RestartStream() + if not self.use_tls: return False - elif not self.authenticated and self.ssl_support: - tls_ns = 'urn:ietf:params:xml:ns:xmpp-tls' - self.add_handler("<proceed xmlns='%s' />" % tls_ns, - self._handle_tls_start, - name='TLS Proceed', - instream=True) - self.send_xml(xml) + elif self.ssl_support: + self.register_handler( + Callback('STARTTLS Proceed', + MatchXPath(tls.Proceed.tag_name()), + tls_proceed, + instream=True)) + self.send(features['starttls']) return True else: log.warning("The module tlslite is required to log in" +\ " to some servers, and has not been found.") return False - def _handle_tls_start(self, xml): + def _handle_sasl_auth(self, features): """ - Handle encrypting the stream using TLS. + Handle authenticating using SASL. - Restarts the stream. + Arguments: + features -- The stream features stanza. """ - log.debug("Starting TLS") - if self.start_tls(): + + def sasl_success(stanza): + """SASL authentication succeeded. Restart the stream.""" + self.authenticated = True + self.features.append('mechanisms') raise RestartStream() - def _handle_sasl_auth(self, xml): + def sasl_fail(stanza): + """SASL authentication failed. Disconnect and shutdown.""" + log.info("Authentication failed.") + self.event("failed_auth", direct=True) + self.disconnect() + log.debug("Starting SASL Auth") + return True + + self.register_handler( + Callback('SASL Success', + MatchXPath(sasl.Success.tag_name()), + sasl_success, + instream=True, + once=True)) + + self.register_handler( + Callback('SASL Failure', + MatchXPath(sasl.Failure.tag_name()), + sasl_fail, + instream=True, + once=True)) + + for priority, mech in self._sasl_mechanism_priorities: + if mech in self._sasl_mechanism_handlers: + handler = self._sasl_mechanism_handlers[mech] + if handler(self): + break + else: + log.error("No appropriate login method.") + self.disconnect() + + return True + + def _handle_sasl_plain(self, xmpp): """ - Handle authenticating using SASL. + Attempt to authenticate using SASL PLAIN. Arguments: - xml -- The SASL mechanisms stanza. + xmpp -- The SleekXMPP connection instance. """ - if self.use_tls and \ - '{urn:ietf:params:xml:ns:xmpp-tls}starttls' in self.features: + if not xmpp.boundjid.user: return False - log.debug("Starting SASL Auth") - sasl_ns = 'urn:ietf:params:xml:ns:xmpp-sasl' - self.add_handler("<success xmlns='%s' />" % sasl_ns, - self._handle_auth_success, - name='SASL Sucess', - instream=True) - self.add_handler("<failure xmlns='%s' />" % sasl_ns, - self._handle_auth_fail, - name='SASL Failure', - instream=True) - - sasl_mechs = xml.findall('{%s}mechanism' % sasl_ns) - if sasl_mechs: - for sasl_mech in sasl_mechs: - self.features.append("sasl:%s" % sasl_mech.text) - if 'sasl:PLAIN' in self.features and self.boundjid.user: - if sys.version_info < (3, 0): - user = bytes(self.boundjid.user) - password = bytes(self.password) - else: - user = bytes(self.boundjid.user, 'utf-8') - password = bytes(self.password, 'utf-8') - - auth = base64.b64encode(b'\x00' + user + \ - b'\x00' + password).decode('utf-8') - - self.send("<auth xmlns='%s' mechanism='PLAIN'>%s</auth>" % ( - sasl_ns, - auth)) - elif 'sasl:ANONYMOUS' in self.features and not self.boundjid.user: - self.send("<auth xmlns='%s' mechanism='%s' />" % ( - sasl_ns, - 'ANONYMOUS')) - else: - log.error("No appropriate login method.") - self.disconnect() + if sys.version_info < (3, 0): + user = bytes(self.boundjid.user) + password = bytes(self.password) + else: + user = bytes(self.boundjid.user, 'utf-8') + password = bytes(self.password, 'utf-8') + + auth = base64.b64encode(b'\x00' + user + \ + b'\x00' + password).decode('utf-8') + + resp = sasl.Auth(xmpp) + resp['mechanism'] = 'PLAIN' + resp['value'] = auth + resp.send() + return True - def _handle_auth_success(self, xml): + def _handle_sasl_anonymous(self, xmpp): """ - SASL authentication succeeded. Restart the stream. + Attempt to authenticate using SASL ANONYMOUS. Arguments: - xml -- The SASL authentication success element. + xmpp -- The SleekXMPP connection instance. """ - self.authenticated = True - self.features = [] - raise RestartStream() + if xmpp.boundjid.user: + return False - def _handle_auth_fail(self, xml): - """ - SASL authentication failed. Disconnect and shutdown. + resp = sasl.Auth(xmpp) + resp['mechanism'] = 'ANONYMOUS' + resp.send() - Arguments: - xml -- The SASL authentication failure element. - """ - log.info("Authentication failed.") - self.event("failed_auth", direct=True) - self.disconnect() + return True - def _handle_bind_resource(self, xml): + def _handle_bind_resource(self, features): """ Handle requesting a specific resource. Arguments: - xml -- The bind feature element. + features -- The stream features stanza. """ log.debug("Requesting resource: %s" % self.boundjid.resource) - xml.clear() - iq = self.Iq(stype='set') + iq = self.Iq() + iq['type'] = 'set' + iq.enable('bind') if self.boundjid.resource: - res = ET.Element('resource') - res.text = self.boundjid.resource - xml.append(res) - iq.append(xml) + iq['bind']['resource'] = self.boundjid.resource response = iq.send() - bind_ns = 'urn:ietf:params:xml:ns:xmpp-bind' - self.set_jid(response.xml.find('{%s}bind/{%s}jid' % (bind_ns, - bind_ns)).text) + self.set_jid(response['bind']['jid']) self.bound = True + log.info("Node set to: %s" % self.boundjid.full) - session_ns = 'urn:ietf:params:xml:ns:xmpp-session' - if "{%s}session" % session_ns not in self.features or self.bindfail: + + if 'session' not in features['features']: log.debug("Established Session") self.sessionstarted = True self.session_started_event.set() self.event("session_start") - def _handle_start_session(self, xml): + def _handle_start_session(self, features): """ Handle the start of the session. Arguments: - xml -- The session feature element. + feature -- The stream features element. """ - if self.authenticated and self.bound: - iq = self.makeIqSet(xml) - response = iq.send() - log.debug("Established Session") - self.sessionstarted = True - self.session_started_event.set() - self.event("session_start") - else: - # Bind probably hasn't happened yet. - self.bindfail = True + iq = self.Iq() + iq['type'] = 'set' + iq.enable('session') + response = iq.send() + + log.debug("Established Session") + self.sessionstarted = True + self.session_started_event.set() + self.event("session_start") def _handle_roster(self, iq, request=False): """ diff --git a/sleekxmpp/stanza/__init__.py b/sleekxmpp/stanza/__init__.py index dbf7b86f..4481fa42 100644 --- a/sleekxmpp/stanza/__init__.py +++ b/sleekxmpp/stanza/__init__.py @@ -12,3 +12,7 @@ from sleekxmpp.stanza.stream_error import StreamError from sleekxmpp.stanza.iq import Iq from sleekxmpp.stanza.message import Message from sleekxmpp.stanza.presence import Presence +from sleekxmpp.stanza.stream_features import StreamFeatures +from sleekxmpp.stanza.bind import Bind +from sleekxmpp.stanza.session import Session + diff --git a/sleekxmpp/stanza/bind.py b/sleekxmpp/stanza/bind.py new file mode 100644 index 00000000..ae1f96f0 --- /dev/null +++ b/sleekxmpp/stanza/bind.py @@ -0,0 +1,26 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + +from sleekxmpp.stanza import Iq, StreamFeatures +from sleekxmpp.xmlstream import ElementBase, ET, register_stanza_plugin + + +class Bind(ElementBase): + + """ + """ + + name = 'bind' + namespace = 'urn:ietf:params:xml:ns:xmpp-bind' + interfaces = set(('resource', 'jid')) + sub_interfaces = interfaces + plugin_attrib = 'bind' + + +register_stanza_plugin(Iq, Bind) +register_stanza_plugin(StreamFeatures, Bind) diff --git a/sleekxmpp/stanza/sasl.py b/sleekxmpp/stanza/sasl.py new file mode 100644 index 00000000..e55a72ad --- /dev/null +++ b/sleekxmpp/stanza/sasl.py @@ -0,0 +1,104 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + +from sleekxmpp.stanza import StreamFeatures +from sleekxmpp.xmlstream import ElementBase, StanzaBase, ET +from sleekxmpp.xmlstream import register_stanza_plugin + + +class Mechanisms(ElementBase): + + """ + """ + + name = 'mechanisms' + namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' + interfaces = set(('mechanisms', 'required')) + plugin_attrib = name + is_extension = True + + def get_required(self): + """ + """ + return True + + def get_mechanisms(self): + """ + """ + results = [] + mechs = self.findall('{%s}mechanism' % self.namespace) + if mechs: + for mech in mechs: + results.append(mech.text) + return results + + def set_mechanisms(self, values): + """ + """ + self.del_mechanisms() + for val in values: + mech = ET.Element('{%s}mechanism' % self.namespace) + mech.text = val + self.append(mech) + + def del_mechanisms(self): + """ + """ + mechs = self.findall('{%s}mechanism' % self.namespace) + if mechs: + for mech in mechs: + self.xml.remove(mech) + + +class Success(StanzaBase): + + """ + """ + + name = 'success' + namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' + interfaces = set() + plugin_attrib = name + + +class Failure(StanzaBase): + + """ + """ + + name = 'failure' + namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' + interfaces = set() + plugin_attrib = name + + +class Auth(StanzaBase): + + """ + """ + + name = 'auth' + namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' + interfaces = set(('mechanism', 'value')) + plugin_attrib = name + + def setup(self, xml): + StanzaBase.setup(self, xml) + self.xml.tag = self.tag_name() + + def set_value(self, value): + self.xml.text = value + + def get_value(self): + return self.xml.text + + def del_value(self): + self.xml.text = '' + + +register_stanza_plugin(StreamFeatures, Mechanisms) diff --git a/sleekxmpp/stanza/session.py b/sleekxmpp/stanza/session.py new file mode 100644 index 00000000..c9d97157 --- /dev/null +++ b/sleekxmpp/stanza/session.py @@ -0,0 +1,25 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + +from sleekxmpp.stanza import Iq, StreamFeatures +from sleekxmpp.xmlstream import ElementBase, ET, register_stanza_plugin + + +class Session(ElementBase): + + """ + """ + + name = 'session' + namespace = 'urn:ietf:params:xml:ns:xmpp-session' + interfaces = set() + plugin_attrib = 'session' + + +register_stanza_plugin(Iq, Session) +register_stanza_plugin(StreamFeatures, Session) diff --git a/sleekxmpp/stanza/stream_features.py b/sleekxmpp/stanza/stream_features.py new file mode 100644 index 00000000..5be2e55f --- /dev/null +++ b/sleekxmpp/stanza/stream_features.py @@ -0,0 +1,52 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + +from sleekxmpp.xmlstream import ElementBase, StanzaBase, ET +from sleekxmpp.xmlstream import register_stanza_plugin + + +class StreamFeatures(StanzaBase): + + """ + """ + + name = 'features' + namespace = 'http://etherx.jabber.org/streams' + interfaces = set(('features', 'required', 'optional')) + sub_interfaces = interfaces + + def setup(self, xml): + StanzaBase.setup(self, xml) + self.values = self.values + + def get_features(self): + """ + """ + return self.plugins + + def set_features(self, value): + """ + """ + pass + + def del_features(self): + """ + """ + pass + + def get_required(self): + """ + """ + features = self['features'] + return [f for n, f in features.items() if f['required']] + + def get_optional(self): + """ + """ + features = self['features'] + return [f for n, f in features.items() if not f['required']] diff --git a/sleekxmpp/stanza/tls.py b/sleekxmpp/stanza/tls.py new file mode 100644 index 00000000..d85f9b49 --- /dev/null +++ b/sleekxmpp/stanza/tls.py @@ -0,0 +1,50 @@ +""" + SleekXMPP: The Sleek XMPP Library + Copyright (C) 2010 Nathanael C. Fritz + This file is part of SleekXMPP. + + See the file LICENSE for copying permission. +""" + +from sleekxmpp.stanza import StreamFeatures +from sleekxmpp.xmlstream import StanzaBase, ElementBase +from sleekxmpp.xmlstream import register_stanza_plugin + + +class STARTTLS(ElementBase): + + """ + """ + + name = 'starttls' + namespace = 'urn:ietf:params:xml:ns:xmpp-tls' + interfaces = set(('required',)) + plugin_attrib = name + + def get_required(self): + """ + """ + return True + + +class Proceed(StanzaBase): + + """ + """ + + name = 'proceed' + namespace = 'urn:ietf:params:xml:ns:xmpp-tls' + interfaces = set() + + +class Failure(StanzaBase): + + """ + """ + + name = 'failure' + namespace = 'urn:ietf:params:xml:ns:xmpp-tls' + interfaces = set() + + +register_stanza_plugin(StreamFeatures, STARTTLS) diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py index b8a7ceaa..28f78f3c 100644 --- a/sleekxmpp/xmlstream/stanzabase.py +++ b/sleekxmpp/xmlstream/stanzabase.py @@ -1064,7 +1064,9 @@ class ElementBase(object): Defaults to True. """ stanza_ns = '' if top_level_ns else self.namespace - return tostring(self.xml, xmlns='', stanza_ns=stanza_ns) + return tostring(self.xml, xmlns='', + stanza_ns=stanza_ns, + top_level = not top_level_ns) def __repr__(self): """ @@ -1276,7 +1278,8 @@ class StanzaBase(ElementBase): stanza_ns = '' if top_level_ns else self.namespace return tostring(self.xml, xmlns='', stanza_ns=stanza_ns, - stream=self.stream) + stream=self.stream, + top_level = not top_level_ns) # To comply with PEP8, method names now use underscores. diff --git a/sleekxmpp/xmlstream/tostring/tostring.py b/sleekxmpp/xmlstream/tostring/tostring.py index 38b08d82..a6bb6ebc 100644 --- a/sleekxmpp/xmlstream/tostring/tostring.py +++ b/sleekxmpp/xmlstream/tostring/tostring.py @@ -7,7 +7,8 @@ """ -def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): +def tostring(xml=None, xmlns='', stanza_ns='', stream=None, + outbuffer='', top_level=False): """ Serialize an XML object to a Unicode string. @@ -26,6 +27,8 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): stream -- The XML stream that generated the XML object. outbuffer -- Optional buffer for storing serializations during recursive calls. + top_level -- Indicates that the element is the outermost + element. """ # Add previous results to the start of the output. output = [outbuffer] @@ -39,14 +42,21 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): else: tag_xmlns = '' + default_ns = '' + stream_ns = '' + if stream: + default_ns = stream.default_ns + stream_ns = stream.stream_ns + # Output the tag name and derived namespace of the element. namespace = '' - if tag_xmlns not in ['', xmlns, stanza_ns]: + if top_level and tag_xmlns not in ['', default_ns, stream_ns] or \ + tag_xmlns not in ['', xmlns, stanza_ns, stream_ns]: namespace = ' xmlns="%s"' % tag_xmlns - if stream and tag_xmlns in stream.namespace_map: - mapped_namespace = stream.namespace_map[tag_xmlns] - if mapped_namespace: - tag_name = "%s:%s" % (mapped_namespace, tag_name) + if stream and tag_xmlns in stream.namespace_map: + mapped_namespace = stream.namespace_map[tag_xmlns] + if mapped_namespace: + tag_name = "%s:%s" % (mapped_namespace, tag_name) output.append("<%s" % tag_name) output.append(namespace) diff --git a/sleekxmpp/xmlstream/tostring/tostring26.py b/sleekxmpp/xmlstream/tostring/tostring26.py index 11501780..3d1ca3d7 100644 --- a/sleekxmpp/xmlstream/tostring/tostring26.py +++ b/sleekxmpp/xmlstream/tostring/tostring26.py @@ -10,7 +10,8 @@ from __future__ import unicode_literals import types -def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): +def tostring(xml=None, xmlns='', stanza_ns='', stream=None, + outbuffer='', top_level=False): """ Serialize an XML object to a Unicode string. @@ -29,6 +30,8 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): stream -- The XML stream that generated the XML object. outbuffer -- Optional buffer for storing serializations during recursive calls. + top_level -- Indicates that the element is the outermost + element. """ # Add previous results to the start of the output. output = [outbuffer] @@ -42,14 +45,21 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None, outbuffer=''): else: tag_xmlns = u'' + default_ns = '' + stream_ns = '' + if stream: + default_ns = stream.default_ns + stream_ns = stream.stream_ns + # Output the tag name and derived namespace of the element. namespace = u'' - if tag_xmlns not in ['', xmlns, stanza_ns]: + if top_level and tag_xmlns not in ['', default_ns, stream_ns] or \ + tag_xmlns not in ['', xmlns, stanza_ns, stream_ns]: namespace = u' xmlns="%s"' % tag_xmlns - if stream and tag_xmlns in stream.namespace_map: - mapped_namespace = stream.namespace_map[tag_xmlns] - if mapped_namespace: - tag_name = u"%s:%s" % (mapped_namespace, tag_name) + if stream and tag_xmlns in stream.namespace_map: + mapped_namespace = stream.namespace_map[tag_xmlns] + if mapped_namespace: + tag_name = u"%s:%s" % (mapped_namespace, tag_name) output.append(u"<%s" % tag_name) output.append(namespace) diff --git a/tests/test_tostring.py b/tests/test_tostring.py index 638e613a..e456d28e 100644 --- a/tests/test_tostring.py +++ b/tests/test_tostring.py @@ -102,11 +102,13 @@ class TestToString(SleekTest): """ Test that stanza objects are serialized properly. """ + self.stream_start() + utf8_message = '\xe0\xb2\xa0_\xe0\xb2\xa0' if not hasattr(utf8_message, 'decode'): # Python 3 utf8_message = bytes(utf8_message, encoding='utf-8') - msg = Message() + msg = self.Message() msg['body'] = utf8_message.decode('utf-8') expected = '<message><body>\xe0\xb2\xa0_\xe0\xb2\xa0</body></message>' result = msg.__str__() |