(**************************************************************************)
(*                   Cameleon                                             *)
(*                                                                        *)
(*      Copyright (C) 2002 Institut National de Recherche en Informatique et   *)
(*      en Automatique. All rights reserved.                              *)
(*                                                                        *)
(*      This program is free software; you can redistribute it and/or modify  *)
(*      it under the terms of the GNU General Public License as published by  *)
(*      the Free Software Foundation; either version 2 of the License, or  *)
(*      any later version.                                                *)
(*                                                                        *)
(*      This program is distributed in the hope that it will be useful,   *)
(*      but WITHOUT ANY WARRANTY; without even the implied warranty of    *)
(*      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the     *)
(*      GNU General Public License for more details.                      *)
(*                                                                        *)
(*      You should have received a copy of the GNU General Public License  *)
(*      along with this program; if not, write to the Free Software       *)
(*      Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA          *)
(*      02111-1307  USA                                                   *)
(*                                                                        *)
(*      Contact: Maxence.Guesdon@inria.fr                                *)
(**************************************************************************)

(** Support for MySQL databases through the OCaml-MySQL library: the DBMS
   specification and the code generator. *)

open Dbf_types.Current

module M = Dbf_messages

(** The SQL column types known for MySQL, one constructor per entry of
   [type_strings]. The [_X] and [_M] suffixes mark types whose SQL name
   carries a parameter: a size (e.g. FLOAT(M), VARCHAR(M)) or a list of
   values (ENUM(vals), SET(vals)).
   The declaration order is significant for polymorphic comparison and
   must not change. *)
type col_type =
  (* Numeric types. *)
  | TINYINT | SMALLINT | MEDIUMINT | INT | BIGINT
  | FLOAT_X | FLOAT | DOUBLE | DECIMAL
  (* Date and time types. *)
  | DATE | DATETIME | TIMESTAMP | TIME | YEAR
  (* String, binary and enumerated types. *)
  | CHAR_X | VARCHAR_M | NCHAR_X | NVARCHAR_M
  | TINYBLOB | TINYTEXT | BLOB | TEXT
  | MEDIUMBLOB | MEDIUMTEXT | LONGBLOB | LONGTEXT
  | ENUM_X | SET_X

(* Names of the MySQL column types, numeric types first.
   A parenthesized M or vals is a placeholder for the type parameter
   (a size, or a list of values).
   NOTE: the value [float] below shadows the standard [float]
   conversion function for the rest of this file. *)
let tinyint = "TINYINT"
and smallint = "SMALLINT"
and mediumint = "MEDIUMINT"
and int = "INT"
and bigint = "BIGINT"
and float_x = "FLOAT(M)"
and float = "FLOAT"
and double = "DOUBLE"
and decimal = "DECIMAL"

(* Date and time types. *)
let date = "DATE"
and datetime = "DATETIME"
and timestamp = "TIMESTAMP"
and time = "TIME"
and year = "YEAR"

(* String, binary and enumerated types. *)
let char_x = "CHAR(M)"
and varchar_x = "VARCHAR(M)"
and nchar_x = "NATIONAL CHAR(M)"
and nvarchar_x = "NATIONAL VARCHAR(M)"
and tinyblob = "TINYBLOB"
and tinytext = "TINYTEXT"
and blob = "BLOB"
and text = "TEXT"
and mediumblob = "MEDIUMBLOB"
and mediumtext = "MEDIUMTEXT"
and longblob = "LONGBLOB"
and longtext = "LONGTEXT"
and enum_x = "ENUM(vals)"
and set_x = "SET(vals)"

(* Pseudo-type: the parameter given with it is used verbatim as the SQL
   type. *)
let other = "other"

(* Association list from each MySQL type name to its [col_type]
   constructor, listed in declaration order. *)
let type_strings =
  [ (* Numeric types. *)
    (tinyint, TINYINT); (smallint, SMALLINT); (mediumint, MEDIUMINT);
    (int, INT); (bigint, BIGINT);
    (float_x, FLOAT_X); (float, FLOAT); (double, DOUBLE); (decimal, DECIMAL);
    (* Date and time types. *)
    (date, DATE); (datetime, DATETIME); (timestamp, TIMESTAMP);
    (time, TIME); (year, YEAR);
    (* String, binary and enumerated types. *)
    (char_x, CHAR_X); (varchar_x, VARCHAR_M);
    (nchar_x, NCHAR_X); (nvarchar_x, NVARCHAR_M);
    (tinyblob, TINYBLOB); (tinytext, TINYTEXT);
    (blob, BLOB); (text, TEXT);
    (mediumblob, MEDIUMBLOB); (mediumtext, MEDIUMTEXT);
    (longblob, LONGBLOB); (longtext, LONGTEXT);
    (enum_x, ENUM_X); (set_x, SET_X) ]


(** Specification of MySQL for DBForge: the column types it offers, the
   column attributes and keys it supports, the OCaml-MySQL conversion
   functions, and the header of the generated code. *)
class mysql_spec =
  object
    (** Identifier of the DBMS described by this specification. *)
    method dbms = Mysql
    (** Display name of the DBMS. *)
    method name = "Mysql"

    (** Column types offered for MySQL, in the same order as
       [type_strings]. The second component is [Some p] when the type
       string contains the placeholder [p] (M in FLOAT(M), CHAR(M),
       VARCHAR(M), ...; vals in ENUM(vals) and SET(vals)), and [None]
       when the type takes no parameter. *)
    method types = [
      tinyint, None ; 
      smallint, None ;
      mediumint, None ;
      int, None ;
      bigint, None ;
      float_x, Some "M" ;
      float, None ;
      double, None;
      decimal, None ;

      date, None ;
      datetime, None ;
      timestamp, None ;
      time, None ;
      year, None ;

      char_x, Some "M" ;
      varchar_x, Some "M" ;
      nchar_x, Some "M";
      nvarchar_x, Some "M" ;
      tinyblob, None ;
      tinytext, None ;
      blob, None ;
      text, None ;
      mediumblob, None ;
      mediumtext, None ;
      longblob, None ;
      longtext, None ;
      enum_x, Some "vals";
      set_x, Some "vals" ;
    ] 
	
    (** OCaml source text, presumably emitted at the top of each generated
       module. It opens [Mysql] and defines the helpers that the generated
       code relies on: [mExecError], [string_of_pred_list],
       [string_of_sqlstring], [sqlstring_of_string], [string_or_null],
       [apply_opt], [update_table], the [db] type, [connect] and
       [disconnect].
       NOTE(review): [Dbf_messages.mExecError] is spliced into a string
       literal of the generated code without escaping, so a double quote
       or a backslash in that message would break the generated code.
       NOTE(review): the generated [update_table] binds [res] without
       using it (OCaml warning 26), and it re-raises [Mysql.Error] as
       [Failure] without the query text. *)
    method header = 
"open Mysql


let mExecError req = \""^Dbf_messages.mExecError^" \"^req

let string_of_pred_list l =
  let rec iter acc = function
      [] -> acc
    | (c,Some v) :: [] -> acc^c^\"=\"^v
    | (c,None) :: [] -> acc^c^\" IS NULL\"
    | (c,Some v) :: q -> iter (acc^c^\"=\"^v^\" AND \") q
    | (c,None) :: q -> iter (acc^c^\" IS NULL AND \") q
  in
  iter \"\" l\n\n

let string_of_sqlstring = Mysql.str2ml

let sqlstring_of_string = Mysql.ml2str

(* Return a the string [\"NULL\"] if None or the given string if Some.*)
let string_or_null s_opt =
  match s_opt with
    None -> \"NULL\"
  | Some s -> s

(* Apply a function to an optional value. *)
let apply_opt f v_opt =
   match v_opt with
     None -> None
   | Some v -> Some (f v)

(* Generic update function used in the [update] function of each table.*)
let update_table db table
    pred_list set_list =
  let query = \"update \"^table^\" set \"^
      (String.concat \", \" set_list)^
	 (match pred_list with
	   [] -> \"\"
	 | _ -> \" where \"^(String.concat \" AND \" pred_list))
  in
  try
    let res = Mysql.exec db query in
    match Mysql.errmsg db with
      None -> ()
    | Some s ->
        raise (Failure ((mExecError query)^\" \"^s))
  with
   | Error s -> raise (Failure s)

type db = Mysql.dbd

let connect
    ?(host : string option) 
    ?(port : int option) 
    ?password user database =
  try 
    Mysql.connect 
      { dbhost = host ;
	dbname = Some database;
	dbport = port ;
	dbpwd = password ;
	dbuser = Some user ;
      } 
  with Mysql.Error s -> raise (Failure s)

let disconnect db = Mysql.disconnect db
"


    (** Extra column attributes for MySQL: a single attribute named Misc,
       of kind [Att_string]. *)
    method col_attributes = [ ("Misc",  Att_string) ]

    (** Kinds of key that a MySQL column can be declared with. *)
    method col_keys = ([Primary_key ; Key] : t_key list)

    (** Fully qualified names of the OCaml-MySQL conversion functions
       whose names end in 2ml (database-to-OCaml conversions, judging by
       their names). *)
    method funs_2ml =
      List.map (fun s -> "Mysql."^s)
	[ 
	  "blob2ml" ;
	  "date2ml" ;
	  "datetime2ml" ;
	  "enum2ml" ;
	  "float2ml" ;
	  "int2ml" ;
	  "int322ml" ;
	  "int642ml" ;
	  "nativeint2ml" ;
	  "set2ml" ;
	  "str2ml" ;
	  "time2ml" ;
	  "timestamp2ml" ;
	  "year2ml" ;	  
	] 

    (** Fully qualified names of the OCaml-MySQL conversion functions
       whose names start with ml2 (OCaml-to-SQL conversions, judging by
       their names). *)
    method funs_ml2 =
      List.map (fun s -> "Mysql."^s)
	[ 
	  "ml2blob" ;
	  "ml2int" ;
	  "ml322int" ;
	  "ml642int" ;
	  "ml2float" ;
	  "ml2enum" ;
	  "ml2set" ;
	  "ml2str" ;
	  "ml2datetime" ;
	  "ml2datetimel" ;
	  "ml2date" ;
	  "ml2datel" ;
	  "ml2time" ;
	  "ml2timel" ;
	  "ml2year" ;
	  "ml2timestamp" ;
	  "ml2timestampl" ;
	] 
  end


(** The MySQL specification, exposed through the generic DBMS
   specification interface. *)
let spec : Dbf_dbms.dbms_spec =
  let mysql = new mysql_spec in
  (mysql :> Dbf_dbms.dbms_spec)


(* Shorthand for [Format.fprintf], used by the code generator below. *)
let p ppf = Format.fprintf ppf

(** Code generation to use OCaml-MySQL: SQL column definitions and the
   OCaml code executing queries through the [Mysql] module. *)
class mysql_gen spec =
  object (self)
    inherit Dbf_odbc.odbc_gen spec

    (** Get the SQL code to define the given column. The result is the
       column name, then its MySQL type followed by the extra type
       arguments (escaped), then [not null], [primary key] or [key], and
       [default] clauses as needed.
       @raise Not_found if [c] has no description for [spec#dbms].
       @raise Failure if the type string is not a known MySQL type, or if
       a parameterized type comes without its parameter. *)
    method column_def c =
      let cdbms = List.assoc spec#dbms c.col_dbms in
      let t =
        let (s, v_opt, args_opt) = cdbms.col_type_sql in
        let s_args = match args_opt with None -> "" | Some a -> a in
        (* A CHAR/VARCHAR size is either an integer literal, kept as is, or
           an OCaml expression, spliced into the generated string literal
           through [string_of_int]. Only [int_of_string]'s failure is
           caught. *)
        let maybe_int_code s =
          try ignore (int_of_string s); s
          with Failure _ -> Printf.sprintf "\"^(string_of_int (%s))^\"" s
        in
        let s_type =
          match v_opt with
          | _ when s = tinyint -> "TINYINT"
          | _ when s = smallint -> "SMALLINT"
          | _ when s = mediumint -> "MEDIUMINT"
          | _ when s = int -> "INT"
          | _ when s = bigint -> "BIGINT"
          | Some x when s = float_x -> "FLOAT("^x^")"
          | _ when s = float -> "FLOAT"
          | _ when s = double -> "DOUBLE"
          | _ when s = decimal -> "DECIMAL"

          | _ when s = date -> "DATE"
          | _ when s = datetime -> "DATETIME"
          | _ when s = timestamp -> "TIMESTAMP"
          | _ when s = time -> "TIME"
          | _ when s = year -> "YEAR"

          | Some x when s = char_x -> "CHAR("^(maybe_int_code x)^")"
          | Some x when s = varchar_x -> "VARCHAR("^(maybe_int_code x)^")"
          | Some x when s = nchar_x -> "NCHAR("^(maybe_int_code x)^")"
          | Some x when s = nvarchar_x -> "NVARCHAR("^(maybe_int_code x)^")"
          | _ when s = tinyblob -> "TINYBLOB"
          | _ when s = tinytext -> "TINYTEXT"
          | _ when s = blob -> "BLOB"
          | _ when s = text -> "TEXT"
          | _ when s = mediumblob -> "MEDIUMBLOB"
          | _ when s = mediumtext -> "MEDIUMTEXT"
          | _ when s = longblob -> "LONGBLOB"
          | _ when s = longtext -> "LONGTEXT"
          | Some vals when s = enum_x -> "ENUM("^vals^")"
          | Some vals when s = set_x -> "SET("^vals^")"
          (* Pseudo-type: the parameter is used verbatim as the SQL type. *)
          | Some o when s = other -> o
          | _ ->
              failwith
                (Format.sprintf "%s: %s %s"
                   M.incorrect_type_definition
                   s (match v_opt with None -> "<None>" | Some n -> n))
        in
        s_type^" "^(String.escaped s_args)
      in
      let not_null = if c.col_nullable then "" else " not null" in
      let key =
        match cdbms.col_key with
        | None -> ""
        | Some Primary_key -> " primary key"
        | Some Key -> " key"
      in
      let default =
        match cdbms.col_default with
        | None -> ""
        | Some v -> " default "^v
      in
      c.col_name^" "^t^not_null^key^default

    (** Print the generated code that executes [query] on [db] and
       discards its result. A MySQL error, whether reported by
       [Mysql.errmsg] or raised as [Error], becomes [Failure] with the
       query text. The result is [ignore]d rather than bound, so the
       generated code does not trigger the unused-variable warning. *)
    method gen_exec fmt =
      p fmt "%s\n\n"
        ("      try\n"^
         "        ignore (Mysql.exec db query);\n"^
         "        match Mysql.errmsg db with\n"^
         "          None -> ()\n"^
         "        | Some s ->\n"^
         "            raise (Failure ((mExecError query)^\" \"^s))\n"^
         "      with\n"^
         "      | Error s -> raise (Failure ((mExecError query)^\" \"^s))"
        )

    (** Print the generated code that executes the select [query] on [db]
       and returns the fetched rows as a list of records, with one field
       per column of [table]. Optional debugging statements are
       requested from [self#gen_debug] at each step. A MySQL error becomes
       [Failure] with the query text. *)
    method gen_select_exec fmt table =
      p fmt "%s"
        ("      try\n"^
         "        let res = Mysql.exec db query in\n");
      self#gen_debug fmt "prerr_endline \"res = Mysql.exec db OK, let's see errmsg...\";";
      p fmt "%s"
        ("        match Mysql.errmsg db with\n"^
         "          None ->\n");
      self#gen_debug fmt "prerr_endline \"errmsg=None\";";
      p fmt "%s"
        ("            let rec loop acc =\n"^
         "              match Mysql.fetch res with\n"^
         "                None ->\n");
      self#gen_debug fmt "prerr_endline \"fetch returns None\";";
      p fmt "%s"
        ("                   List.rev acc\n"^
         "              | Some a ->\n");
      self#gen_debug fmt "prerr_endline \"fetch returns Some\";";
      p fmt "%s" "                 let t = {\n";
      (* One record field per column, read from position [n] of the row
         array [a]: nullable columns go through [apply_opt], the others
         through [not_null]. Field names are the lowercased column names.
         NOTE(review): they must match the record type declared for the
         table elsewhere (presumably by the parent generator), so confirm
         that it lowercases names the same way. [String.lowercase] was
         removed in OCaml 5.0, hence the ASCII variant. *)
      List.iteri
        (fun n col ->
          p fmt "                    %s = %s %s a.(%d);\n"
            (String.lowercase_ascii col.col_name)
            (if col.col_nullable then "apply_opt" else "not_null")
            (self#col_2ml col)
            n)
        table.ta_columns;
      p fmt "%s"
        ("                   }\n"^
         "                 in\n"^
         "                 loop (t::acc)\n"^
         "            in\n"^
         "            loop []\n"^
         "        | Some s ->\n");
      self#gen_debug fmt "prerr_endline \"errmsg=Some\";";
      p fmt "%s"
        ("            raise (Failure ((mExecError query)^\" \"^s))\n"^
         "      with\n"^
         "      | Error s -> raise (Failure ((mExecError query)^\" \"^s))\n\n"
        )

  end
