#author: James Henstridge <james@daa.com.au>

"""This module implements an interface to MySQL databases, conforming
to the API published by the Python DB-SIG at
http://www.python.org/sigs/db-sig/DatabaseAPI.html

It is really just a wrapper for an older python interface to mySQL databases
called mySQL, which I modified to facilitate use of a cursor.  That module was
Joseph Skinner's port of the mSQL module by David Gibson, which was a modified
version of Anthony Baxter's msql module.

As an example, to add some extra (unprivileged) users to your database system,
and then delete them again:

>>> import mysqldb
>>> conn = mysqldb.mysqldb('mysql@localhost root rootpasswd')
>>> curs = conn.cursor()
>>> curs.execute("insert into user (host, user) values ('%s', '%s')",
... [('localhost', 'linus'), ('somewhere.com.au', 'james')])
2
>>> curs.execute("select * from user")
>>> curs.fetchall()
 -- record listing --
>>> curs.execute("delete from user where host = 'somewhere.com.au' or user = 'linus'")
2
>>> curs.close()
>>> conn.close()

The argument to mysqldb.mysqldb is of the form 'db@host user pass',
'db@host user', 'db@host', 'db', 'db user pass' or 'db user'.

As always, the source is a good manual :-)

James Henstridge <james@daa.com.au>
"""

import mySQL
from string import upper, split, join

# Exception raised by this module (an old-style string exception);
# errors from the underlying mySQL module are re-raised as this.
error = 'mysqldb.error'

# Maps the type name found in each fields() entry of a result onto the
# DB API type category reported in cursor.description.
_type = {}
for _category, _names in (
		('STRING', ('char', 'varchar', 'string', 'unhandled', '????')),
		('RAW', ('tiny blob', 'medium blob', 'long blob', 'blob')),
		('NUMBER', ('short', 'long', 'float', 'double', 'decimal')),
		('DATE', ('date', 'time', 'datetime', 'timestamp'))):
	for _name in _names:
		_type[_name] = _category
# Keep the loop variables out of the module namespace.
del _category, _names, _name

def isDDL(q):
	return upper(split(q)[0]) in ('CREATE', 'ALTER', 'GRANT', 'REVOKE',
		'DROP', 'SET')
def isDML(q):
	return upper(split(q)[0]) in ('DELETE', 'INSERT', 'UPDATE', 'LOAD')
def isDQL(q):
	return upper(split(q)[0]) in ('SELECT', 'SHOW', 'DESC', 'DESCRIBE')

class Connection:
	"""This is the connection object for the mySQL database interface.

	Attribute lookups that fail on the connection itself are forwarded to
	an internal Cursor, so a Connection can be used directly like a
	cursor (conn.execute(...), conn.fetchall(), ...).
	"""
	def __init__(self, host, user, passwd, db):
		"""Connect to host as user/passwd and select database db.

		Raises error (carrying the mySQL.error message) on failure.
		"""
		# Give both private slots a value before anything can fail, so
		# that close() (run from __del__) and __getattr__ behave on a
		# half-constructed object.
		self.__conn = None
		self.__curs = None
		try:
			self.__conn = mySQL.connect(host, user, passwd)
			self.__conn.selectdb(db)
		except mySQL.error, msg:
			raise error, msg
		self.__curs = Cursor(self.__conn)

	def __del__(self):
		self.close()

	def __getattr__(self, key):
		"""Forward unknown attribute lookups to the internal cursor."""
		# Never forward the lookup of our own cursor slot: if it is
		# missing (instance built without __init__) that would recurse.
		if key == '_Connection__curs':
			raise AttributeError, key
		return getattr(self.__curs, key)

	def __setattr__(self, key, val):
		"""Cursor attributes are set on the internal cursor (which
		rejects 'description'); everything else is stored here."""
		if key in ('arraysize', 'description', 'insert_id'):
			setattr(self.__curs, key, val)
		else:
			self.__dict__[key] = val

	def __check(self):
		"""Raise error if this connection has been closed."""
		if self.__conn is None: raise error, "Connection is closed."

	def close(self):
		"""Close the connection.

		The internal cursor is closed as well, so the cursor methods
		reached through __getattr__ raise error afterwards.  Cursors
		already returned by cursor() hold their own reference to the
		underlying mySQL connection and are not affected.
		"""
		if self.__curs is not None:
			self.__curs.close()
		self.__conn = None

	def cursor(self):
		"""Return a new Cursor on this connection."""
		self.__check()
		return Cursor(self.__conn)

	# No-ops, present for DB API compatibility.
	def commit(self): pass
	def rollback(self): pass
	def callproc(self, params=None): pass

	# These functions are just here so that every action that is
	# covered by mySQL is covered by mysqldb.  They are not standard
	# DB API.  The list* methods are not included, since they can be
	# done with the SQL SHOW command.
	def __admin(self, name, *args):
		"""Call the mySQL connection method called name with args.

		Raises error if the connection is closed, and re-raises
		mySQL.error as error like the rest of this module.
		"""
		self.__check()
		try:
			return apply(getattr(self.__conn, name), args)
		except mySQL.error, msg:
			raise error, msg
	def create(self, dbname):
		"""This is not a standard part of Python DB API.
		Passes dbname to the mySQL connection's create()."""
		return self.__admin('create', dbname)
	def drop(self, dbname):
		"""This is not a standard part of Python DB API.
		Passes dbname to the mySQL connection's drop()."""
		return self.__admin('drop', dbname)
	def reload(self):
		"""This is not a standard part of Python DB API.
		Calls the mySQL connection's reload()."""
		return self.__admin('reload')
	def shutdown(self):
		"""This is not a standard part of Python DB API.
		Calls the mySQL connection's shutdown()."""
		return self.__admin('shutdown')


class Cursor:
	"""A cursor object for use with connecting to mySQL databases.

	Attributes:
	  arraysize   -- default number of rows for fetchmany(); starts at 1.
	  description -- read-only.  None unless the last execute() produced
	                 a result set; then a tuple with one entry per column,
	                 (name, type, f[3], f[3], f[1]), built from the
	                 result's fields() entry f.  f[1] is used to qualify
	                 duplicate column names in the fetch*Dict methods, so
	                 it is presumably the table name.
	  insert_id   -- insert_id() of the most recent query's result.
	"""
	def __init__(self, conn):
		self.__conn = conn
		self.__res = None
		self.arraysize = 1
		# description is read-only through __setattr__, so bypass it.
		self.__dict__['description'] = None
		self.__open = 1
		self.insert_id = 0

	def __del__(self):
		self.close()

	def __setattr__(self, key, val):
		if key == 'description':
			raise error, "description is a read-only attribute."
		else:
			self.__dict__[key] = val

	def __delattr__(self, key):
		if key in ('description', 'arraysize', 'insert_id'):
			raise error, "%s can't be deleted." % (key,)
		else:
			del self.__dict__[key]

	def close(self):
		"""Close the cursor; later execute/fetch calls raise error."""
		self.__conn = None
		self.__res = None
		self.__open = 0

	def __query(self, q):
		"""Run the SQL string q, keeping its result and insert id.

		mySQL.error is re-raised as this module's error.
		"""
		try:
			self.__res = self.__conn.querycursor(q)
			self.insert_id = self.__res.insert_id()
		except mySQL.error, msg:
			raise error, msg

	def __describe(self):
		"""Set description from the field list of the current result."""
		try:
			f = self.__res.fields()
		except mySQL.error, msg:
			raise error, msg
		# Type names missing from _type are reported as 'STRING' (as
		# 'unhandled' already is) instead of escaping as a KeyError.
		self.__dict__['description'] = tuple(map(
			lambda x: (x[0], _type.get(x[2], 'STRING'), x[3],
				x[3], x[1]), f))

	def execute(self, op, params=None):
		"""Execute the SQL statement op.

		With params, op is a %-format string and params is either one
		tuple/list of values or a sequence of them.  Queries (isDQL) are
		run with the last parameter set only; every other statement is
		run once per parameter set.

		Returns 1 for DDL, the (summed) affected-row count for statements
		that report one, and None when the statement produced a result
		set (description is then set and rows can be fetched).
		Raises error if the cursor is closed or the server reports one.
		"""
		if not self.__open: raise error, "Cursor has been closed."
		if not params:
			self.__query(op)
			self.__dict__['description'] = None
			if isDDL(op):
				return 1
			af = self.__res.affectedrows()
			if af != -1:
				return af
			# affectedrows() of -1 is treated as "returned a result set".
			self.__describe()
			return None

		# A single parameter set may be passed on its own.
		if type(params[0]) not in (type(()), type([])):
			params = [params]
		if isDQL(op):
			# A query yields one result set, so only the last set is run.
			self.__query(op % params[-1])
			self.__describe()
			return None

		ddl = isDDL(op)
		self.__dict__['description'] = None
		af = 0
		for x in params:
			self.__query(op % x)
			if not ddl:
				af = af + self.__res.affectedrows()
		if ddl:
			return 1
		if isDML(op) or self.__res.affectedrows() != -1:
			return af
		# Unrecognised statement that produced a result set: describe it
		# rather than silently discarding it.
		self.__describe()
		return None

	def __result(self):
		"""Return the current result, or raise error if there is none."""
		if not self.__open: raise error, "Cursor has been closed."
		if self.__res is None: raise error, "no query made yet."
		return self.__res

	def fetchone(self):
		"""Return the next row of the current result."""
		res = self.__result()
		try:
			return res.fetchone()
		except mySQL.error, msg:
			raise error, msg

	def fetchmany(self, size=None):
		"""Return up to size rows (arraysize if size is None or 0)."""
		res = self.__result()
		try:
			return res.fetchmany(size or self.arraysize)
		except mySQL.error, msg:
			raise error, msg

	def fetchall(self):
		"""Return all remaining rows of the current result."""
		res = self.__result()
		try:
			return res.fetchall()
		except mySQL.error, msg:
			raise error, msg


	def __getColKeys(self):
		"""This is not a standard part of Python DB API.

		Return one dictionary key per result column: the column name, or
		'<description[i][4]>.<name>' when several columns share a name.
		Raises error if the last statement produced no result set.
		"""
		if self.description is None:
			raise error, "no result set to fetch from."
		count = {}
		for d in self.description:
			count[d[0]] = count.get(d[0], 0) + 1
		colKeys = []
		for d in self.description:
			if count[d[0]] > 1:
				colKeys.append(join([d[4], d[0]], '.'))
			else:
				colKeys.append(d[0])
		return colKeys

	def __rowDict(self, colKeys, row):
		"""Return a dictionary mapping colKeys[i] to row[i]."""
		d = {}
		for i in range(len(row)):
			d[colKeys[i]] = row[i]
		return d

	def fetchoneDict(self):
		"""This is not a standard part of Python DB API.

		Like fetchone(), but the row is a dictionary keyed by column.
		Returns None when fetchone() returns None.
		"""
		res = self.__result()
		colKeys = self.__getColKeys()
		try:
			row = res.fetchone()
		except mySQL.error, msg:
			raise error, msg
		# Presumably fetchone() returns None once rows run out (DB API
		# convention); the original crashed on len(None) here.
		if row is None:
			return None
		return self.__rowDict(colKeys, row)

	def fetchmanyDict(self, size=None):
		"""This is not a standard part of Python DB API.

		Like fetchmany(), but each row is a dictionary keyed by column.
		"""
		res = self.__result()
		colKeys = self.__getColKeys()
		try:
			rows = res.fetchmany(size or self.arraysize)
		except mySQL.error, msg:
			raise error, msg
		result = []
		for row in rows:
			result.append(self.__rowDict(colKeys, row))
		return result

	def fetchallDict(self):
		"""This is not a standard part of Python DB API.

		Like fetchall(), but each row is a dictionary keyed by column.
		"""
		res = self.__result()
		colKeys = self.__getColKeys()
		try:
			rows = res.fetchall()
		except mySQL.error, msg:
			raise error, msg
		result = []
		for row in rows:
			result.append(self.__rowDict(colKeys, row))
		return result

	# No-ops, present for DB API compatibility.
	def setinputsizes(self, sizes): pass
	def setoutputsize(self, size, col=None): pass


def mysqldb(connect_string):
	"""Makes a connection to the MySQL server.  The Argument should be of
	the form 'db@host user pass' or 'db@host user' or 'db@host' or 'db'
	or 'db user pass' or 'db user', where db is the database name, host
	is the server's host name, user is your user name, and pass is your
	password.

	A missing host defaults to 'localhost'; a missing user or password
	is passed on as ''.  Returns a Connection.  Raises error if no
	database name is given (e.g. '' or '@host').
	"""
	val = split(connect_string)
	if len(val) == 0: raise error, "no database specified"
	while len(val) < 3: val.append('')
	dh = split(val[0], '@')
	# split() with a separator always returns at least one item, so the
	# case to reject is an empty database name such as '@host'.
	if dh[0] == '': raise error, "no database specified"
	while len(dh) < 2: dh.append('')
	if dh[1] == '': dh[1] = 'localhost'
	return Connection(dh[1], val[1], val[2], dh[0])

