# -*- coding: utf-8 -*-

import dateutil.parser

class inspectdb(object):
    """Describe existing PostgreSQL tables as web2py ``Field`` definitions.

    ``myconf`` and ``Field`` come from the web2py model environment. The
    following configuration entries are honoured:

    * ``<connection section>.tables`` -- comma separated list of tables to
      inspect (default: every table of the ``public`` schema);
    * ``<tablename>.fields`` -- comma separated list of columns to expose;
    * ``<tablename>:<column>`` -- extra ``Field`` keyword arguments;
    * ``inspectdb.<column>`` -- field name to use for a column (mapping it to
      ``id`` also makes it the table's id field).
    """

    # PostgreSQL folds unquoted identifiers to lower case, so only names that
    # match this pattern can be used verbatim in SQL; anything else is quoted.
    _PLAIN_IDENTIFIER_RE = r"[a-z_][a-z0-9_]*\Z"
    # "character varying(N)" / "character(N)" as rendered by format_type().
    _CHAR_TYPE_RE = r"character(?: varying)?\((\d+)\)\Z"
    # format_type() output -> web2py field type, for types without modifiers.
    _SIMPLE_TYPES = {
        "text": "text",
        "integer": "integer",
        "double precision": "double",
        "timestamp without time zone": "datetime",
        "date": "date",
        "json": "json",
    }

    @staticmethod
    def _splitlist(value):
        """Split a comma separated config value, dropping padding and blanks."""
        return [item.strip() for item in value.split(",") if item.strip()]

    @staticmethod
    def _pg_alltabs(odb):
        """Yield the names of all tables of the ``public`` schema."""
        sql = "SELECT tablename FROM pg_tables WHERE schemaname = 'public';"
        for row in odb.executesql(sql):
            yield row[0]

    @classmethod
    def gettabs(cls, k, odb):
        """Return the names of the tables to inspect for connection section ``k``.

        Uses the section's ``tables`` option when present, otherwise every
        table of the ``public`` schema.
        """
        if k in myconf and "tables" in myconf.take(k):
            return myconf.take(k + ".tables", cast=cls._splitlist)
        return list(cls._pg_alltabs(odb))

    @staticmethod
    def _pg_allfieldsraw(odb, tablename):
        """Yield ``(column_name, data_type)`` rows for ``tablename``.

        Only the relation that an unqualified ``tablename`` resolves to (via
        the search path) is inspected; system and dropped columns are left
        out, and columns come back in table order. On a database error the
        transaction is rolled back and nothing is yielded.
        """
        sql = """SELECT a.attname AS column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a
JOIN pg_class b ON (a.attrelid = b.oid)
WHERE b.relname = %s
  AND pg_table_is_visible(b.oid)
  AND a.attnum > 0
  AND NOT a.attisdropped
ORDER BY a.attnum;"""
        try:
            # Bound parameter instead of string interpolation of the name.
            rows = odb.executesql(sql, placeholders=(tablename,))
        except Exception:
            # Best effort: leave the connection usable and report no columns.
            odb.rollback()
        else:
            for row in rows:
                yield row

    @classmethod
    def _pg_typeconf(cls, column_name, data_type):
        """Map a column to base ``Field`` kwargs (``rname``, ``type``, ``length``).

        :param column_name: column name as stored in ``pg_attribute``.
        :param data_type: type as rendered by PostgreSQL ``format_type()``.
        :raises NotImplementedError: if ``data_type`` has no mapping here.
        """
        import re

        if re.match(cls._PLAIN_IDENTIFIER_RE, column_name):
            conf = {"rname": column_name}
        else:
            # Quoting preserves case/special characters; embedded quotes are doubled.
            conf = {"rname": '"%s"' % column_name.replace('"', '""')}

        char_match = re.match(cls._CHAR_TYPE_RE, data_type)
        if column_name == "id":
            conf["type"] = "id"
        elif char_match:
            conf.update(type="string", length=int(char_match.group(1)))
        elif data_type.startswith("geometry"):
            conf["type"] = "geometry()"
        elif data_type in cls._SIMPLE_TYPES:
            conf["type"] = cls._SIMPLE_TYPES[data_type]
        else:
            raise NotImplementedError(data_type)
        return conf

    @classmethod
    def _pg_allfields(cls, odb, tablename):
        """Yield ``(column_name, field_kwargs)`` for each supported column.

        ``field_kwargs["fieldname"]`` is the column name unless the
        ``inspectdb.<column>`` option renames it. Columns whose type cannot be
        mapped are skipped.
        """
        renames = myconf.take("inspectdb") if "inspectdb" in myconf else ()
        for column_name, data_type in cls._pg_allfieldsraw(odb, tablename):
            try:
                conf = cls._pg_typeconf(column_name, data_type)
            except NotImplementedError:
                continue  # unsupported type: leave the column out
            conf["fieldname"] = column_name
            if column_name in renames:
                # rname still points at the real column; only the field is renamed.
                newname = myconf.take("inspectdb." + column_name)
                conf["fieldname"] = newname
                if newname == "id":
                    conf["type"] = "id"
            yield column_name, conf

    @classmethod
    def getfields(cls, odb, tablename):
        """Yield a web2py ``Field`` for each exposed column of ``tablename``.

        With a ``<tablename>.fields`` option only the listed columns are
        yielded, in that order (a listed column that is missing or has an
        unsupported type raises ``KeyError``); otherwise every supported
        column is. A ``<tablename>:<column>`` section overrides the kwargs.
        """
        fieldconfs = cls._pg_allfields(odb, tablename)
        if tablename in myconf and "fields" in myconf.take(tablename):
            fieldnames = myconf.take(tablename + ".fields", cast=cls._splitlist)
            available = dict(fieldconfs)
            fieldconfs = ((fn, available[fn]) for fn in fieldnames)
        for fn, fconf in fieldconfs:
            key = tablename + ":" + fn
            if key in myconf:
                fconf.update(myconf.take(key))
            yield Field(**fconf)

# Prefix of the config keys describing one database connection each; the rest
# of the key is the connection name.
_k_ = "inspectdb:"


def loopOconns(flt=None):
    """Yield ``(connection_name, section)`` for every ``inspectdb:*`` config key.

    :param flt: optional predicate called with the section; keys for which it
        returns a falsy value are skipped.
    """
    prefix_len = len(_k_)
    for section_key, _unused in myconf.iteritems():
        if not section_key.startswith(_k_):
            continue
        section = myconf.take(section_key)
        if flt is not None and not flt(section):
            continue
        yield section_key[prefix_len:], section

#def loopOfields(tablename):
#    """ DEPRECATED """
#    for k,v in myconf.iteritems():
#        if k.startswith(tablename+":"):
#            yield myconf.take(k)

# One DAL connection per "inspectdb:<name>" config section, keyed by <name>.
# pool_size / migrate are optional: the tables are read from an existing
# database, so migrations must be explicitly enabled (default False).
odbs = {
    conn_name: DAL(conn_info["uri"],
                   pool_size=conn_info.get("pool_size", 0),
                   migrate=conn_info.get("migrate", False),
                   check_reserved=['all'])
    for conn_name, conn_info in loopOconns()
}

# Define every inspected table on its connection. Options of the table's own
# config section whose key starts with "_" are passed to define_table with the
# underscore stripped.
for k, odb in odbs.iteritems():
    for tablename in inspectdb.gettabs(_k_ + k, odb):
        if tablename in myconf:
            tabconf = {opt[1:]: val
                       for opt, val in myconf.take(tablename).iteritems()
                       if opt.startswith("_")}
        else:
            tabconf = {}
        fields = list(inspectdb.getfields(odb, tablename))
        # Only tables exposing an id-typed field are defined.
        # TODO: log the tables skipped here.
        if any(fld.type == "id" for fld in fields):
            odb.define_table(tablename, *fields, **tabconf)

# Append an "Inspect dbs" menu entry whose sub-entries are the inspected
# connections, each linking to plugin_inspectdb/index/<connection name>.
response.menu += [
    (
        STRONG(
            SPAN(_class="glyphicon glyphicon-sunglasses",
                 **{"_aria-hidden": "true"}),
            " ",
            T("Inspect dbs"),
            _style="color: yellow;",
        ),
        False,
        "#",
        [(conn, False, URL("plugin_inspectdb", "index", args=(conn,)))
         for conn in odbs],
    ),
]

class DBService(object):
    """Insert helpers for the tables defined on the ``odbs`` connections.

    Incoming values are converted with ``_cast`` according to the target
    field's type; every write goes through ``auth.requires_login()``.
    """

    try:
        _STRING_TYPES = basestring  # Python 2
    except NameError:  # Python 3 has no basestring
        _STRING_TYPES = str

    # Strings meaning "false" for boolean fields (after strip/lower); any other
    # string is true, as bool() would say for a non-empty string.
    _FALSE_STRINGS = frozenset(("", "0", "f", "false", "n", "no", "off"))

    @staticmethod
    def _cast(f, v):
        """Convert ``v`` to the Python value expected by field ``f``.

        f @Field  : object; only ``f.type`` is used
        v         : value to be inserted, usually a string; ``None`` is
                    returned unchanged (stored as NULL)
        Raises NotImplementedError for field types not handled here.
        """
        ftype = f.type
        if v is None or ftype in (None, 'string', 'text'):
            return v
        elif ftype in ('integer', 'bigint', 'id'):
            return int(v)
        elif ftype == 'boolean':
            # bool("false") is True, so textual values need explicit parsing.
            if isinstance(v, DBService._STRING_TYPES):
                return v.strip().lower() not in DBService._FALSE_STRINGS
            return bool(v)
        elif ftype == 'double':
            return float(v)
        elif ftype == 'date':
            return dateutil.parser.parse(v).date()
        elif ftype == 'datetime':
            return dateutil.parser.parse(v)
        elif ftype == 'json':
            import json
            return json.loads(v) if isinstance(v, DBService._STRING_TYPES) else v
        else:
            raise NotImplementedError()

    @classmethod
    def insert(cls, dbname, tablename, **kw):
        """Insert one row and return what the DAL ``insert`` returns (its id).

        dbname    : connection name (key of ``odbs``)
        tablename : table defined on that connection
        kw        : column values, converted with ``_cast``
        """
        @auth.requires_login()
        def _main():
            tab = odbs[dbname][tablename]
            return tab.insert(**{k: cls._cast(tab[k], v) for k, v in kw.iteritems()})
        return _main()

    @classmethod
    def bulk_insert(cls, dbname, tablename, _data):
        """Insert several rows at once and return the DAL ``bulk_insert`` result.

        _data : JSON array of objects (column -> value); an already decoded
                list of dictionaries is accepted too
        """
        @auth.requires_login()
        def _main():
            import json
            # Parse only after the login check has passed.
            if isinstance(_data, cls._STRING_TYPES):
                data = json.loads(_data)
            else:
                data = _data
            tab = odbs[dbname][tablename]
            return tab.bulk_insert(
                [{k: cls._cast(tab[k], v) for k, v in row.iteritems()}
                 for row in data])
        return _main()

@service.json
def db_insert(dbname, tablename, **kw):
    """JSON service: insert one row into ``tablename`` of connection ``dbname``.

    Extra keyword arguments are the column values; the response is
    ``{"id": <value returned by DBService.insert>}``.
    """
    new_id = DBService.insert(dbname, tablename, **kw)
    return {"id": new_id}

@service.json
def db_bulk_insert(dbname, tablename, data):
    """JSON service: insert the rows encoded in ``data`` into ``tablename``.

    The response is ``{"ids": <value returned by DBService.bulk_insert>}``.
    """
    new_ids = DBService.bulk_insert(dbname, tablename, data)
    return {"ids": new_ids}