from xml.etree import cElementTree as ET
import logging
import traceback

class JID(object):
	def __init__(self, jid):
		self.jid = jid
	
	def __getattr__(self, name):
		if name == 'resource':
			return self.jid.split('/', 1)[-1]
		elif name == 'user':
			return self.jid.split('@', 1)[0]
		elif name == 'server':
			return self.jid.split('@', 1)[-1].split('/', 1)[0]
		elif name == 'full':
			return self.jid
		elif name == 'bare':
			return self.jid.split('/', 1)[0]
	
	def __str__(self):
		return self.jid

class ElementBase(object):
	name = 'stanza'
	plugin_attrib = 'plugin'
	namespace = 'jabber:client'
	interfaces = set(('type', 'to', 'from', 'id', 'payload'))
	types = set(('get', 'set', 'error', None, 'unavailable', 'normal', 'chat'))
	sub_interfaces = tuple()
	plugin_attrib_map = {}
	plugin_tag_map = {}
	subitem = None

	def __init__(self, xml=None, parent=None):
		self.attrib = self # backwards compatibility hack
		self.parent = parent
		self.xml = xml
		self.plugins = {}
		self.iterables = []
		self.idx = 0
		if not self.setup(xml):
			for child in self.xml.getchildren():
				if child.tag in self.plugin_tag_map:
					self.plugins[self.plugin_tag_map[child.tag].plugin_attrib] = self.plugin_tag_map[child.tag](xml=child, parent=self)
				if self.subitem is not None and child.tag == "{%s}%s" % (self.subitem.namespace, self.subitem.name):
					self.iterables.append(self.subitem(xml=child, parent=self))

	def __iter__(self):
		self.idx = 0
		return self
	
	def __next__(self):
		self.idx += 1
		if self.idx + 1 > len(self.iterables):
			self.idx = 0
			raise StopIteration
		return self.affiliations[self.idx]
	
	def __len__(self):
		return len(self.iterables)
	
	def append(self, item):
		if not isinstance(item, ElementBase):
			raise TypeError
		self.xml.append(item.xml)
		return self.iterables.append(item)
	
	def pop(self, idx=0):
		aff = self.iterables.pop(idx)
		self.xml.remove(aff.xml)
		return aff
	
	def keys(self):
		out = []
		out += [x for x in self.interfaces]
		out += [x for x in self.plugins]
		if self.iterables:
			out.append('substanzas')
		return tuple(out)
	
	def find(self, item):
		return self.iterables.find(item)

	def match(self, xml):
		return xml.tag == self.tag
	
	def find(self, xpath): # for backwards compatiblity, expose elementtree interface
		return self.xml.find(xpath)
	
	def setup(self, xml=None):
		if self.xml is None:
			self.xml = xml
		if self.xml is None:
			for ename in self.name.split('/'):
				new = ET.Element("{%(namespace)s}%(name)s" % {'name': self.name, 'namespace': self.namespace})
				if self.xml is None:
					self.xml = new
				else:
					self.xml.append(new)
			if self.parent is not None:
				self.parent.xml.append(self.xml)
			return True #had to generate XML
		else:
			return False

	def enable(self, attrib):
		self.initPlugin(attrib)
		return self
	
	def initPlugin(self, attrib):
		if attrib not in self.plugins:
			self.plugins[attrib] = self.plugin_attrib_map[attrib](parent=self)
	
	def __getitem__(self, attrib):
		if attrib == 'substanzas':
			return self.iterables
		elif attrib in self.interfaces:
			if hasattr(self, "get%s" % attrib.title()):
				return getattr(self, "get%s" % attrib.title())()
			else:
				if attrib in self.sub_interfaces:
					return self._getSubText(attrib)
				else:
					return self._getAttr(attrib)
		elif attrib in self.plugin_attrib_map:
			if attrib not in self.plugins: self.initPlugin(attrib)
			return self.plugins[attrib]
		else:
			return ''
	
	def __setitem__(self, attrib, value):
		if attrib in self.interfaces:
			if value is not None:
				if hasattr(self, "set%s" % attrib.title()):
					getattr(self, "set%s" % attrib.title())(value,)
				else:
					if attrib in self.sub_interfaces:
						return self._setSubText(attrib, text=value)
					else:
						self._setAttr(attrib, value)
			else:
				self.__delitem__(attrib)
		elif attrib in self.plugin_attrib_map:
			if attrib not in self.plugins: self.initPlugin(attrib)
			self.initPlugin(attrib)
			self.plugins[attrib][attrib] = value
		return self
	
	def __delitem__(self, attrib):
		if attrib.lower() in self.interfaces:
			if hasattr(self, "del%s" % attrib.title()):
				getattr(self, "del%s" % attrib.title())()
			else:
				if attrib in self.sub_interfaces:
					return self._delSub(attrib)
				else:
					self._delAttr(attrib)
		elif attrib in self.plugin_attrib_map:
			if attrib in self.plugins:
				del self.plugins[attrib]
		return self
	
	def __eq__(self, other):
		values = self.getValues()
		for key in other:
			if key not in values or values[key] != other[key]:
				return False
		return True
	
	def _setAttr(self, name, value):
		if value is None or value == '':
			self.__delitem__(name)
		else:
			self.xml.attrib[name] = value
	
	def _delAttr(self, name):
		if name in self.xml.attrib:
			del self.xml.attrib[name]
	
	def _getAttr(self, name):
		return self.xml.attrib.get(name, '')
	
	def _getSubText(self, name):
		stanza = self.xml.find("{%s}%s" % (self.namespace, name))
		if stanza is None or stanza.text is None:
			return ''
		else:
			return stanza.text
	
	def _setSubText(self, name, attrib={}, text=None):
		if text is None or text == '':
			return self.__delitem__(name)
		stanza = self.xml.find("{%s}%s" % (self.namespace, name))
		if stanza is None:
			#self.xml.append(ET.Element("{%s}%s" % (self.namespace, name), attrib))
			stanza = ET.Element("{%s}%s" % (self.namespace, name))
			self.xml.append(stanza)
		stanza.text = text
		return stanza
		
	def _delSub(self, name):
		for child in self.xml.getchildren():
			if child.tag == "{%s}%s" % (self.namespace, name):
				self.xml.remove(child)
	
	def getValues(self):
		out = {}
		for interface in self.interfaces:
			out[interface] = self[interface]
		for pluginkey in self.plugins:
			out[pluginkey] = self.plugins[pluginkey].getValues()
		if self.iterables:
			iterables = [x.getValues() for x in self.iterables]
			out['substanzas'] = iterables
		return out
	
	def setValues(self, attrib):
		for interface in attrib:
			if interface == 'substanzas':
				for subdict in attrib['substanzas']:
					sub = self.subitem(parent=self)
					sub.setValues(subdict)
					self.iterables.append(sub)
			elif interface in self.interfaces:
				self[interface] = attrib[interface]
			elif interface in self.plugin_attrib_map and interface not in self.plugins:
				self.initPlugin(interface)
			if interface in self.plugins:
				self.plugins[interface].setValues(attrib[interface])
		return self
	
	def append(self, xml):
		self.xml.append(xml)
		return self
	
	def __del__(self):
		if self.parent is not None:
			self.parent.xml.remove(self.xml)

class StanzaBase(ElementBase):
	name = 'stanza'
	namespace = 'jabber:client'
	interfaces = set(('type', 'to', 'from', 'id', 'payload'))
	types = set(('get', 'set', 'error', None, 'unavailable', 'normal', 'chat'))
	sub_interfaces = tuple()

	def __init__(self, stream=None, xml=None, stype=None, sto=None, sfrom=None, sid=None):
		self.stream = stream
		ElementBase.__init__(self, xml)
		if stype is not None:
			self['type'] = stype
		if sto is not None:
			self['to'] = sto
		if sfrom is not None:
			self['from'] = sfrom
		if stream is not None:
			self.namespace = stream.default_ns
		self.tag = "{%s}%s" % (self.namespace, self.name)
	
	def setType(self, value):
		if value in self.types:
				self.xml.attrib['type'] = value
		return self

	def getPayload(self):
		return self.xml.getchildren()
	
	def setPayload(self, value):
		self.xml.append(value)
	
	def delPayload(self):
		self.clear()
	
	def clear(self):
		for child in self.xml.getchildren():
			self.xml.remove(child)
		#for plugin in list(self.plugins.keys()):
		#	del self.plugins[plugin]
	
	def reply(self):
		self['from'], self['to'] = self['to'], self['from']
		self.clear()
		return self
	
	def error(self):
		self['type'] = 'error'
	
	def getTo(self):
		return JID(self._getAttr('to'))
	
	def setTo(self, value):
		return self._setAttr('to', str(value))
	
	def getFrom(self):
		return JID(self._getAttr('from'))
	
	def setFrom(self, value):
		return self._setAttr('from', str(value))
	
	def unhandled(self):
		pass
	
	def exception(self, e):
		logging.error(traceback.format_tb(e))
	
	def send(self):
		self.stream.sendRaw(str(self))

	def __str__(self, xml=None, xmlns='', stringbuffer=''):
		if xml is None:
			xml = self.xml
		newoutput = [stringbuffer]
		#TODO respect ET mapped namespaces
		itag = xml.tag.split('}', 1)[-1]
		if '}' in xml.tag:
			ixmlns = xml.tag.split('}', 1)[0][1:]
		else:
			ixmlns = ''
		nsbuffer = ''
		if xmlns != ixmlns and ixmlns != '' and ixmlns != self.namespace:
			if self.stream is not None and ixmlns in self.stream.namespace_map:
				if self.stream.namespace_map[ixmlns] != '':
					itag = "%s:%s" % (self.stream.namespace_map[ixmlns], itag)
			else:
				nsbuffer = """ xmlns="%s\"""" % ixmlns
		if ixmlns not in ('', xmlns, self.namespace):
			nsbuffer = """ xmlns="%s\"""" % ixmlns
		newoutput.append("<%s" % itag)
		newoutput.append(nsbuffer)
		for attrib in xml.attrib:
			if '{' not in attrib:
				newoutput.append(""" %s="%s\"""" % (attrib, self.xmlesc(xml.attrib[attrib])))
		if len(xml) or xml.text or xml.tail:
			newoutput.append(">")
			if xml.text:
				newoutput.append(self.xmlesc(xml.text))
			if len(xml):
				for child in xml.getchildren():
					newoutput.append(self.__str__(child, ixmlns))
			newoutput.append("</%s>" % (itag, ))
			if xml.tail:
				newoutput.append(self.xmlesc(xml.tail))
		elif xml.text:
			newoutput.append(">%s</%s>" % (self.xmlesc(xml.text), itag))
		else:
			newoutput.append(" />")
		return ''.join(newoutput)

	def xmlesc(self, text):
		text = list(text)
		cc = 0
		matches = ('&', '<', '"', '>', "'")
		for c in text:
			if c in matches:
				if c == '&':
					text[cc] = '&amp;'
				elif c == '<':
					text[cc] = '&lt;'
				elif c == '>':
					text[cc] = '&gt;'
				elif c == "'":
					text[cc] = '&apos;'
				else:
					text[cc] = '&quot;'
			cc += 1
		return ''.join(text)