#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>

#if defined(WIN32) || defined(__WIN32__)
#include <windows.h>
#endif  /* win32 */

#include <mysql/mysql.h>
#include "db.h"

/* Size in bytes (including the terminating NUL) of the static buffer
 * db_query() formats each query into. */
#define MAX_QUERY_SIZE      8192

/* Declared explicitly; presumably because strict C89 headers do not declare
 * the C99 vsnprintf() -- NOTE(review): confirm with the build flags. */
int vsnprintf(char *str, size_t size, const char *format, va_list ap);

/* Connection handle shared by every db_* function; 0 before db_connect()
 * and after db_disconnect(). */
static MYSQL *db;
/* Result set buffered by db_get_row() for the most recent query; 0 when no
 * result set is currently held. */
static MYSQL_RES *res;

int db_connect(const char *host) {
    if(db) {
        mysql_close(db);
    }

    db = mysql_init(0);
    if(!(mysql_real_connect(db, host, "ina", "ina", "ina", 0, 0, 0))) {
        return -1;
    }

    return 0;
}

void db_disconnect(void) {
    mysql_close(db);
    db = 0;
}

/* Format a query printf-style into a static buffer and execute it on the
 * current connection.
 *
 * Arguments are pasted into the SQL text verbatim. Any string that may
 * contain quotes, or that comes from outside the program, must be escaped
 * by the caller (e.g. with mysql_real_escape_string()).
 *
 * Any result set still buffered from the previous query is released, so
 * rows obtained earlier from db_get_row() must not be used afterwards.
 * Arguments may still point into those rows, because formatting happens
 * before the release.
 *
 * Not reentrant: uses a static buffer and the module-level connection.
 *
 * Returns 0 on success and -1 in these cases:
 *   - not connected;
 *   - the formatted query does not fit in MAX_QUERY_SIZE bytes;
 *   - the server rejects the query.
 */
int db_query(const char *str, ...) {
    static char qry[MAX_QUERY_SIZE];
    va_list arg_list;
    int len;

    va_start(arg_list, str);
    len = vsnprintf(qry, MAX_QUERY_SIZE, str, arg_list);
    va_end(arg_list);

    /* Only now drop the old result: the arguments just formatted may have
     * been strings from its rows. Doing it before executing also means a
     * failed query cannot leave db_get_row() serving stale rows. */
    if(res) {
        mysql_free_result(res);
        res = 0;
    }

    if(!db) {
        return -1;
    }

    /* Never run a truncated query. Cutting off the tail can produce a
     * different but still valid statement (e.g. "... WHERE id = 12" from
     * "... WHERE id = 1234"). Pre-C99 implementations signal truncation
     * with a negative return value. */
    if(len < 0 || len >= MAX_QUERY_SIZE) {
        return -1;
    }

    if(mysql_real_query(db, qry, (unsigned long)len) != 0) {
        return -1;
    }

    return 0;
}

/* Return the next row of the most recent query's result set. Returns 0
 * when there are no more rows, when the query produced no result set, or
 * when not connected.
 *
 * The first call after db_query() buffers the whole result with
 * mysql_store_result(). The call that finds no further row releases the
 * buffer. A returned row points into that buffer, so it must not be used
 * once the buffer is released, either by that final call or by the next
 * db_query(). Per the MySQL C API, a column holding SQL NULL is a NULL
 * pointer.
 */
const char **db_get_row(void) {
    MYSQL_ROW row;

    if(!res) {
        /* Nothing buffered yet; fetching needs a live connection. */
        if(!db || !(res = mysql_store_result(db))) {
            return 0;
        }
    }

    if(!(row = mysql_fetch_row(res))) {
        /* Exhausted: release now so the caller does not have to. */
        mysql_free_result(res);
        res = 0;
        return 0;
    }

    return (const char**)row;
}

/* Return the value mysql_insert_id() reports for the current connection.
 * That API yields a 64-bit unsigned value, which is narrowed to unsigned
 * long here. On platforms with a 32-bit long, ids above 4294967295 are
 * therefore truncated. `db` must be a connected handle (see db_connect()). */
unsigned long db_get_last_id(void)
{
    unsigned long last_id;

    last_id = (unsigned long)mysql_insert_id(db);
    return last_id;
}

/* -------------------------------------------- */

int ina_auth(const char *user, const char *pass) {
    const char **row;
    char db_pass_hash[50];

    if(db_query("SELECT passwd FROM users WHERE uname = '%s'", user) == -1) {
        return -1;
    }

    if(!(row = db_get_row())) {
        return 0;
    }
    sprintf(db_pass_hash, row[0]);

    if(db_query("SELECT SHA1('%s')", pass) == -1 || !(row = db_get_row())) {
        return -1;
    }
    pass = row[0];

    if(strcmp(db_pass_hash, pass) == 0) {
        return 1;
    }
    return 0;
}