(**************************************************************************)
(*                   Cameleon                                             *)
(*                                                                        *)
(*      Copyright (C) 2002 Institut National de Recherche en Informatique et   *)
(*      en Automatique. All rights reserved.                              *)
(*                                                                        *)
(*      This program is free software; you can redistribute it and/or modify  *)
(*      it under the terms of the GNU General Public License as published by  *)
(*      the Free Software Foundation; either version 2 of the License, or  *)
(*      any later version.                                                *)
(*                                                                        *)
(*      This program is distributed in the hope that it will be useful,   *)
(*      but WITHOUT ANY WARRANTY; without even the implied warranty of    *)
(*      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the     *)
(*      GNU General Public License for more details.                      *)
(*                                                                        *)
(*      You should have received a copy of the GNU General Public License  *)
(*      along with this program; if not, write to the Free Software       *)
(*      Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA          *)
(*      02111-1307  USA                                                   *)
(*                                                                        *)
(*      Contact: Maxence.Guesdon@inria.fr                                *)
(**************************************************************************)

(** The class for ODBC databases through the OCamlODBC library. *)

open Dbf_types.Current
module M = Dbf_messages

(* Names of the SQL column types offered for ODBC.  They are listed by
   [odbc_spec#types] and compared against in [odbc_gen#column_def]. *)

(** Name of the integer column type. *)
let int : string = "integer"

(** Name of the single precision floating point column type. *)
let float : string = "float"

(** Name of the double precision floating point column type. *)
let double : string = "double"

(** Name of the variable-length string column type; [n] stands for the
    length parameter. *)
let varchar_n : string = "varchar(n)"

(** Name used for any column type not covered by the other names. *)
let other : string = "other"

(** Description of the ODBC backend, exposed with the generic
    [Dbf_dbms.dbms_spec] class type.  Every method returns a constant;
    the only sizeable one is [header], the OCaml prelude copied at the top
    of each generated implementation. *)
class odbc_spec : Dbf_dbms.dbms_spec =
  object
    (** The DBMS identifier of this backend. *)
    method dbms = Odbc

    (** The display name of this backend. *)
    method name = "ODBC"
	
    (** The column types offered for this backend, as pairs of a type name
        and an optional string.
        NOTE(review): the optional string presumably names the parameter
        attached to the type ([n] for varchar, free text for other);
        confirm against [Dbf_dbms]. *)
    method types = [ (int, None) ;
		     (float, None) ;
		     (double, None) ; 
		     (varchar_n, Some "n") ;
		     (other, Some "") ;
		   ] 
	
    (** OCaml source text printed verbatim at the beginning of the generated
        implementation (see [odbc_gen#generate]).  It defines, in order:
        the [mFormatError] and [mExecError] message builders (built from the
        corresponding [Dbf_messages] strings), [string_of_pred_list],
        [escape_string], [string_of_sqlstring], [sqlstring_of_string],
        [string_opt], [string_or_null], [apply_opt], the generic
        [update_table] function, the [db] type, [connect] and [disconnect].
        The whole literal is program output: the comments it contains end up
        in the generated file and are not comments of this file. *)
    method header = 
"open Ocamlodbc

let mFormatError req = \""^Dbf_messages.mFormatError^" \"^req
let mExecError req = \""^Dbf_messages.mExecError^" \"^req

let string_of_pred_list l =
  let rec iter acc = function
      [] -> acc
    | (c,Some v) :: [] -> acc^c^\"=\"^v
    | (c,None) :: [] -> acc^c^\" IS NULL\"
    | (c,Some v) :: q -> iter (acc^c^\"=\"^v^\" AND \") q
    | (c,None) :: q -> iter (acc^c^\" IS NULL AND \") q
  in
  iter \"\" l\n\n

let escape_string s =
  let rec iter acc s =
    let len = String.length s in
    if len = 0 then
      acc
    else
      match s.[0] with
        '\\'' -> iter (acc ^ \"\\\\'\") (String.sub s 1 (len -1))
      | '\\\\' -> iter (acc ^ \"\\\\\\\\\") (String.sub s 1 (len -1))
      | _ -> iter (acc ^ (String.sub s 0 1)) (String.sub s 1 (len -1))
  in
  iter \"\" s
    
let string_of_sqlstring s = s

let sqlstring_of_string s =  \"'\"^(escape_string s)^\"'\"

(* Return an optional string from a string returned by ODBC
   ([None] means that the given string is [\"NULL\"]*)
let string_opt s =
   match s with
     \"NULL\" -> None
   | _ -> Some s

(* Return a the string [\"NULL\"] if None or the given string if Some.*)
let string_or_null s_opt =
  match s_opt with
    None -> \"NULL\"
  | Some s -> s

(* Apply a function to an optional value. *)
let apply_opt f v_opt =
   match v_opt with
     None -> None
   | Some v -> Some (f v)

(* Generic update function used in the [update] function of each table.*)
let update_table db table
    pred_list set_list =
  let query = \"update \"^table^\" set \"^
      (String.concat \", \" set_list)^
	 (match pred_list with
	   [] -> \"\"
	 | _ -> \" where \"^(String.concat \" AND \" pred_list))
  in
  try
    let (n,l) = Ocamlodbc.execute db query in
    if n = 0 then
      ()
    else
      raise (Failure (mExecError query))
  with
  | Failure s -> raise (Failure s)


type db = Ocamlodbc.connection

let connect
    ?(host : string option) 
    ?(port : int option) 
    ?password user database =
  try
   Ocamlodbc.connect database user
     (match password with None -> \"\" | Some p -> p)
  with
    Ocamlodbc.SQL_Error s -> raise (Failure s)

let disconnect db = Ocamlodbc.disconnect db

"

    (** The column attributes specific to this backend: none. *)
    method col_attributes = []

    (** The kinds of key a column can be given with this backend. *)
    method col_keys = ([Primary_key ; Key] : t_key list)

    (** Names of conversion functions, all of which take a string argument
        in the generated code ([string_of_sqlstring] is defined in
        [header]).
        NOTE(review): presumably the choices offered for converting a value
        read from the database to its OCaml type; confirm against the
        callers of [funs_2ml]. *)
    method funs_2ml = [ "int_of_string" ; "float_of_string" ;
			"string_of_sqlstring"]

    (** Names of conversion functions, all of which return a string in the
        generated code ([sqlstring_of_string] is defined in [header]).
        NOTE(review): presumably the choices offered for converting an OCaml
        value to its SQL text; confirm against the callers of [funs_ml2]. *)
    method funs_ml2 = [ "string_of_int" ; "string_of_float" ;
			"sqlstring_of_string"]
  end


(** The ODBC specification object, viewed as a generic
    [Dbf_dbms.dbms_spec]. *)
let spec : Dbf_dbms.dbms_spec = (new odbc_spec :> Dbf_dbms.dbms_spec)

(** [remove_table_prefix s] returns [s] without the prefix stored in
    [Dbf_args.remove_table_prefix], when [s] starts with that prefix
    (compared case-insensitively) and is strictly longer than it.
    In every other case [s] is returned unchanged, so the result is never
    the empty string unless [s] already was. *)
let remove_table_prefix s =
  let prefix = !Dbf_args.remove_table_prefix in
  let prefix_len = String.length prefix in
  let name_len = String.length s in
  (* Evaluated only when [s] is long enough, so [String.sub] cannot fail. *)
  let starts_with_prefix () =
    String.lowercase (String.sub s 0 prefix_len) = String.lowercase prefix
  in
  if name_len > prefix_len && starts_with_prefix () then
    String.sub s prefix_len (name_len - prefix_len)
  else
    s


(** Shorthand for [Format.fprintf]: [p fmt format arg1 ... argn]. *)
let p fmt format = Format.fprintf fmt format

(** Code generation to use OCamlODBC.

    Each [gen_*] method prints a fragment of the generated implementation
    on the given formatter, and each [i_gen_*] method prints the matching
    fragment of the generated interface.  The generated functions rely on
    the helpers defined in the text returned by [spec#header].

    [spec] must provide the [header] and [dbms] methods, which are the
    only ones called here. *)
class odbc_gen spec =
  object (self)

    (** {2 Implementation generation} *)

    (** Print the whole implementation for [schema] on [fmt]: the header of
        [spec], then one module per table, then flush [fmt].
        The tables of [schema] are sorted with [Dbf_misc.sort_tables]
        first. *)
    method generate fmt schema =
      Format.fprintf fmt "%s\n" spec#header;
      Dbf_misc.sort_tables schema;
      List.iter (self#gen_table fmt) schema.sch_tables;
      Format.pp_print_flush fmt ()

    (** Print the module implementing the access functions of [table].
        The module name is the capitalized table name, without the prefix
        handled by [remove_table_prefix]. *)
    method gen_table fmt table =
      Format.fprintf fmt "module %s =\n  struct\n"
        (String.capitalize (remove_table_prefix table.ta_name));
      self#gen_type fmt table;
      self#gen_create fmt table;
      self#gen_select fmt table;
      self#gen_insert fmt table;
      self#gen_update fmt table;
      self#gen_delete fmt table;
      self#gen_drop fmt table;
      Format.fprintf fmt "  end\n\n"

    (** The [col_ml2] field of the information column [c] holds for the
        DBMS of [spec].
        @raise Not_found if [c] has no information for that DBMS. *)
    method col_ml2 c = (List.assoc spec#dbms c.col_dbms).col_ml2

    (** The [col_2ml] field of the information column [c] holds for the
        DBMS of [spec].
        @raise Not_found if [c] has no information for that DBMS. *)
    method col_2ml c = (List.assoc spec#dbms c.col_dbms).col_2ml

    (** Get the SQL code to define the given column: its name, its type,
        then the optional [not null], key and default clauses.
        The result is inserted between double quotes in the generated code
        (see [gen_create_query]), hence the escaping of the type arguments.
        @raise Failure if the type name is unknown, or if it needs a value
        which is missing.
        @raise Not_found if the column has no information for the DBMS of
        [spec]. *)
    method column_def c =
      let cdbms = List.assoc spec#dbms c.col_dbms in
      let t =
        let (s, v_opt, args_opt) = cdbms.col_type_sql in
        let s_args = match args_opt with None -> "" | Some a -> a in
        (* An integer literal is used as is.  Anything else is taken as an
           OCaml expression of type int: the generated string literal is
           closed, the expression is converted with [string_of_int], and
           the literal is reopened.  [int_of_string] only raises
           [Failure], so nothing else is caught. *)
        let maybe_int_code s =
          try ignore (int_of_string s); s
          with Failure _ -> Printf.sprintf "\"^(string_of_int (%s))^\"" s
        in
        match v_opt with
          _ when s = int -> "integer "^(String.escaped s_args)
        | _ when s = float -> "float "^(String.escaped s_args)
        | _ when s = double -> "double "^(String.escaped s_args)
        | Some n when s = varchar_n ->
            "varchar("^(maybe_int_code n)^")" ^
            (String.escaped s_args)
        | Some o when s = other -> o^" "^(String.escaped s_args)
        | _ -> raise (Failure (Format.sprintf "%s: %s %s"
                                 M.incorrect_type_definition
                                 s (match v_opt with None -> "<None>" | Some n -> n)))
      in
      c.col_name^" "^t^
      (match c.col_nullable with
        true -> ""
      | false -> " not null")^
      (match cdbms.col_key with
        None -> ""
      | Some Primary_key -> " primary key"
      | Some Key -> " key"
      )^
      (match cdbms.col_default with
        None -> ""
      | Some v -> " default "^v)

    (** Output the code for the type representing a record of the given
        table: a record type [t] with one mutable field per column, named
        after the lowercased column name, of type option when the column is
        nullable.  The columns of [table] are sorted with
        [Dbf_misc.sort_columns] first. *)
    method gen_type fmt table =
      Dbf_misc.sort_columns table;
      p fmt "    type t = {\n";
      List.iter
        (fun c ->
          p fmt
            "      mutable %s : (%s) %s ; (** %s *)\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then "option" else "")
            c.col_comment
        )
        table.ta_columns;
      p fmt "      }\n\n"

    (** Output the code executing the [query].  The generated code raises
        [Failure] when [Ocamlodbc.execute] returns a non-zero code. *)
    method gen_exec fmt =
      p fmt "%s\n\n"
        ("      try\n"^
         "        let (n,l) = Ocamlodbc.execute db query in\n"^
         "        if n = 0 then\n"^
         "          ()\n"^
         "        else\n"^
         "          raise (Failure (mExecError query))\n"^
         "      with\n"^
         "      | Failure s -> raise (Failure s)"
        )

    (** Output debug code if debug mode is on. *)
    method gen_debug fmt code =
      if !Dbf_args.debug then p fmt "%s" code

    (** Output debug code to print [query] if debug mode is on. *)
    method gen_debug_query fmt =
      self#gen_debug fmt "prerr_endline query;\n"

    (** Output the header of the [create] function for the given table.*)
    method gen_create_header fmt table =
      p fmt "    let create db =\n"

    (** Output the code to define [query] in the [create] function for the
        given table, with one definition per column (see [column_def]). *)
    method gen_create_query fmt table =
      p fmt "      let query = \"create table %s (%s)\" in\n"
        table.ta_name
        (String.concat ", "
           (List.map (fun col -> self#column_def col) table.ta_columns))
      (* TODO: append the extra options of the table (table#extra_options). *)

    (** Output the code executing the [query] in the [create] function. *)
    method gen_create_exec = self#gen_exec

    (** Output the code for the function creating the given table. *)
    method gen_create fmt table =
      self#gen_create_header fmt table;
      self#gen_create_query fmt table;
      self#gen_debug_query fmt;
      self#gen_create_exec fmt
      (* TODO: indexes on the columns, attributes of the table. *)

    (** Output the header of the [select] function for the given table:
        one optional argument per column, after [?distinct]. *)
    method gen_select_header fmt table =
      p fmt "    let select db ?(distinct=false)\n        ";
      List.iter (fun col -> p fmt "?%s " (String.lowercase col.col_name)) table.ta_columns;
      p fmt "() =\n"

    (** Output the code to define [query] in the [select] function for the
        given table.  In the generated code, every column is selected and
        each optional argument which was given adds a predicate to the
        where clause. *)
    method gen_select_query fmt table =
      (* the list of columns in the select query *)
      p fmt "      let columns = [";
      List.iter
        (fun c -> p fmt "\"%s\";" c.col_name)
        table.ta_columns;
      p fmt "] in\n";

      (* the analysis of the columns in the where clause *)
      p fmt "%s"
        ("      let l_pred = List.filter\n"^
         "          (fun v -> v <> None)\n"^
         "          [\n"
        );

      List.iter
        (fun col ->
          p fmt "            (match %s with\n" (String.lowercase col.col_name);
          p fmt "               None -> None\n";
          p fmt "             | Some v -> Some (\"%s\", apply_opt %s %s));\n"
            (String.lowercase col.col_name)
            (self#col_ml2 col)
            (if col.col_nullable then "v" else "(Some v)")
        )
        table.ta_columns;

      p fmt "%s"
        ("          ]\n"^
         "      in\n"^
         "      let l_pred2 = List.map\n"^
         "        (fun p -> match p with (Some pred) -> pred | None -> \"\",None) l_pred in\n"
        );

      p fmt "%s"
        ("      let query = \"select \"^(if distinct then \"distinct\" else \"\")^\" \"^\n"^
         "        (String.concat \", \" columns)^\n"
        );
      p fmt "         \" from %s \"^\n" table.ta_name ;
      p fmt "%s"
        ("        (match l_pred with\n"^
         "           [] -> \"\"\n"^
         "         | _ -> \"where \"^(string_of_pred_list l_pred2))\n"^
         "      in\n"
        )

    (** Output the code executing the [query] in the [select] function.
        The generated code converts each returned row, which must hold
        exactly one string per column, into a record of type [t]; it raises
        [Failure] on any other row shape or on a non-zero return code. *)
    method gen_select_exec fmt table =
      p fmt "%s"
        ("      try\n"^
         "        let (n,l) = Ocamlodbc.execute db query in\n"^
         "        if n = 0 then\n"^
         "          let f = function\n"
        );
      p fmt
        "             | %s :: [] ->\n"
        (String.concat " :: "
           (List.map (fun c -> "v_"^(String.lowercase c.col_name)) table.ta_columns));
      p fmt
        "               {\n";
      List.iter
        (fun col ->
          p fmt "                  %s = %s%s (%sv_%s);\n"
            (String.lowercase col.col_name)
            (if col.col_nullable then "apply_opt " else "")
            (self#col_2ml col)
            (if col.col_nullable then "string_opt " else "")
            (String.lowercase col.col_name)
        )
        table.ta_columns;

      p fmt "%s"
        ("               }\n"^
         "            | _ -> raise (Failure (mFormatError query))\n"^
         "          in\n"^
         "          List.map f l\n"^
         "        else\n"^
         "          raise (Failure (mExecError query))\n"^
         "      with\n"^
         "      | Failure s -> raise (Failure s)\n\n"
        )

    (** Output the code of the [select] function for the given table. *)
    method gen_select fmt table =
      (* the columns always appear in alphabetical order in the parameters list *)
      Dbf_misc.sort_columns table;
      (* output the header of the function *)
      self#gen_select_header fmt table;
      (* output the creation of the query *)
      self#gen_select_query fmt table;
      self#gen_debug_query fmt;
      (* output the execution of query *)
      self#gen_select_exec fmt table

    (** Output the header of the [insert] function for the given table:
        one optional argument per column. *)
    method gen_insert_header fmt table =
      p fmt "    let insert db ";
      List.iter (fun col -> p fmt "?%s " (String.lowercase col.col_name)) table.ta_columns;
      p fmt "() =\n"

    (** Output the code to define [query] in the [insert] function for the
        given table.  In the generated code, only the columns whose
        argument was given appear in the query. *)
    method gen_insert_query fmt table =
      (* the analysis of the given values *)
      p fmt "%s"
        ("      let l_val = List.filter\n"^
         "          (fun v -> v <> None)\n"^
         "          [\n"
        );
      List.iter
        (fun col ->
          p fmt "            (match %s with\n" (String.lowercase col.col_name);
          p fmt "             | None -> None\n";
          p fmt "             | Some v -> Some (\"%s\", apply_opt %s %s)) ;\n"
            (String.lowercase col.col_name)
            (self#col_ml2 col)
            (if col.col_nullable then "v" else "(Some v)")
        )
        table.ta_columns;
      p fmt "%s"
        ("          ]\n"^
         "      in\n"^
         "      let l_val2 = List.map (fun v -> match v with (Some cpl) -> cpl | None -> \"\",None) l_val in\n"
        );
      (* output the creation of the query *)
      p fmt "      let query = \"insert into %s \"^\n" table.ta_name;
      p fmt "%s"
        ("        \"(\"^(String.concat \", \" (List.map fst l_val2))^\")\"^\n"^
         "        \"values (\"^(String.concat \", \" "^
         "                      (List.map (fun (_,s_opt) -> string_or_null s_opt) l_val2))^\")\"\n"^
         "      in\n"
        )

    (** Output the code executing the [query] in the [insert] function. *)
    method gen_insert_exec = self#gen_exec

    (** Output the code of [insert] function for the given table. *)
    method gen_insert fmt table =
      (* the columns always appear in alphabetical order in the parameters list *)
      Dbf_misc.sort_columns table;
      self#gen_insert_header fmt table;
      self#gen_insert_query fmt table;
      self#gen_debug_query fmt;
      self#gen_insert_exec fmt

    (** Output the header of the [update] function for the given table:
        one optional argument per column for the new values, then one
        optional [key_] argument per column for the selection of the
        records. *)
    method gen_update_header fmt table =
      p fmt "    let update db ";
      List.iter (fun col -> p fmt "?%s " (String.lowercase col.col_name)) table.ta_columns;
      p fmt "\n        ";
      List.iter (fun col -> p fmt "?key_%s " (String.lowercase col.col_name)) table.ta_columns;
      p fmt "() =\n"

    (** Output the body of the [update] function for the given table.
        The generated code builds the list of assignments from the plain
        arguments and the list of predicates from the [key_] arguments,
        then calls the [update_table] function defined in the header. *)
    method gen_update_body fmt table =
      (* the analysis of the given set-values *)
      p fmt "%s"
        ("      let l_set = List.filter\n"^
         "          (fun v -> v <> None)\n"^
         "          [\n"
        );
      List.iter
        (fun col ->
          p fmt "            (match %s with\n" (String.lowercase col.col_name);
          p fmt "             | None -> None\n";
          p fmt "             | Some v -> Some (\"%s=\"^(string_or_null (apply_opt %s %s)))) ;\n"
            (String.lowercase col.col_name)
            (self#col_ml2 col)
            (if col.col_nullable then "v" else "(Some v)")
        )
        table.ta_columns;
      p fmt "%s"
        ("          ]\n"^
         "      in\n"^
         "      let l_set2 = List.map (fun v -> match v with (Some cpl) -> cpl | None -> \"\") l_set in\n"
        );
      (* the analysis of the given pred-values *)
      p fmt "%s"
        ("      let l_pred = List.filter\n"^
         "          (fun v -> v <> None)\n"^
         "          [\n"
        );
      List.iter
        (fun col ->
          p fmt "            (match key_%s with\n" (String.lowercase col.col_name);
          p fmt "             | None -> None\n";
          (* The space before IS NULL is required: the predicate is the
             column name immediately followed by this text. *)
          p fmt "             | Some v -> Some (\"%s\"^(match apply_opt %s %s with None -> \" IS NULL\" | Some s -> \"=\"^s))) ;\n"
            (String.lowercase col.col_name)
            (self#col_ml2 col)
            (if col.col_nullable then "v" else "(Some v)")
        )
        table.ta_columns;
      p fmt "%s"
        ("          ]\n"^
         "      in\n"^
         "      let l_pred2 = List.map (fun v -> match v with (Some cpl) -> cpl | None -> \"\") l_pred in\n"
        );

      (* output the call to the generic update function *)
      p fmt "      update_table db \"%s\" l_pred2 l_set2 \n\n" table.ta_name

    (** Output code of the [update] function of the given table. *)
    method gen_update fmt table =
      (* the columns always appear in alphabetical order in the parameters list *)
      Dbf_misc.sort_columns table;
      self#gen_update_header fmt table;
      self#gen_update_body fmt table

    (** Output the header of the [delete] function for the given table:
        one optional argument per column. *)
    method gen_delete_header fmt table =
      p fmt "    let delete db ";
      List.iter (fun col -> p fmt "?%s " (String.lowercase col.col_name)) table.ta_columns;
      p fmt "() =\n"

    (** Output the code to define [query] in the [delete] function for the
        given table.  In the generated code, each optional argument which
        was given adds a predicate to the where clause; without any
        argument the query has no where clause. *)
    method gen_delete_query fmt table =
      (* the analysis of the columns in the where clause *)
      p fmt "%s"
        ("      let l_pred = List.filter\n"^
         "          (fun v -> v <> None)\n"^
         "          [\n"
        );
      List.iter
        (fun col ->
          p fmt "            (match %s with\n" (String.lowercase col.col_name);
          p fmt "             | None -> None\n";
          p fmt "             | Some v -> Some (\"%s\", apply_opt %s %s)) ;\n"
            (String.lowercase col.col_name)
            (self#col_ml2 col)
            (if col.col_nullable then "v" else "(Some v)")
        )
        table.ta_columns;
      p fmt "%s"
        ("          ]\n"^
         "      in\n"^
         "      let l_pred2 ="^
         "        List.map (fun p -> match p with (Some pred) -> pred | None -> \"\",None) l_pred"^
         "      in\n"
        );
      (* output the creation of the query *)
      p fmt "      let query = \"delete from %s \"^\n" table.ta_name;
      p fmt "%s"
        ("        (match l_pred with\n"^
         "          [] ->\n"^
         "             \"\"\n"^
         "        | _ ->\n"^
         "             \"where \"^(string_of_pred_list l_pred2))\n"^
         "      in\n"
        )

    (** Output the code executing the [query] in the [delete] function. *)
    method gen_delete_exec = self#gen_exec

    (** Output the code of the [delete] function for the given table. *)
    method gen_delete fmt table =
      (* the columns always appear in alphabetical order in the parameters list *)
      Dbf_misc.sort_columns table;
      self#gen_delete_header fmt table;
      self#gen_delete_query fmt table;
      self#gen_debug_query fmt;
      self#gen_delete_exec fmt

    (** Output the header of the [drop] function for the given table.*)
    method gen_drop_header fmt table =
      p fmt "    let drop db =\n"

    (** Output the code to define [query] in the [drop] function for the given table.*)
    method gen_drop_query fmt table =
      p fmt "      let query = \"drop table %s\" in\n" table.ta_name

    (** Output the code executing the [query] in the [drop] function.
        This method must be defined only once: a second definition in the
        same object is rejected by recent compilers. *)
    method gen_drop_exec = self#gen_exec

    (** Output the code of the [drop] function for the given table.*)
    method gen_drop fmt table =
      self#gen_drop_header fmt table;
      self#gen_drop_query fmt table;
      self#gen_debug_query fmt;
      self#gen_drop_exec fmt

    (** {2 Interface generation} *)

    (** Print the whole interface for [schema] on [fmt]: a leading comment
        (the optional title and the generation notice), the [db] type, the
        signatures of the functions of the header, then one module
        signature per table.  The [db] type is manifest only when
        [Dbf_args.db_manifest] is set; [Dbf_args.gen_code] must then be
        set too.  The tables of [schema] are sorted with
        [Dbf_misc.sort_tables] first. *)
    method i_generate fmt schema =
      Dbf_misc.sort_tables schema;
      p fmt "(** %s%s *)\n\n"
        (match !Dbf_args.title with
          None -> ""
        | Some t -> t^"\n\n")
        Dbf_messages.generated_by;
      p fmt "type db %s\n\n"
        (if !Dbf_args.db_manifest then
          (match !Dbf_args.gen_code with
            None -> assert false
          | Some Odbc -> " = Ocamlodbc.connection"
          | Some Mysql -> " = Mysql.dbd"
          | Some Postgres -> " = Postgres.connection")
        else
          ""
        );
      p fmt "%s"
        ("(** [connect user database] connects to [database] as [user].\n"^
         "   @return a handle to the database connection.\n"^
         "   @param host not used in OCamlODBC.\n"^
         "   @param port not used in OCamlODBC.\n"^
         "   @param password optional password, [\"\"] is used by default.\n"^
         "   @raise Failure with a message if an error occurs.\n"^
         "*)\n"^
         "val connect :\n"^
         "    ?host : string ->\n"^
         "      ?port : int  ->\n"^
         "	?password : string ->\n"^
         "	  string -> string -> db\n\n"^
         "\n"^
         "(** Disconnect from database. *)\n"^
         "val disconnect : db -> unit\n"^
         "\n"^
         "val string_of_sqlstring : string -> string\n"^
         "val sqlstring_of_string : string -> string\n"
        );
      List.iter (self#i_gen_table fmt) schema.sch_tables

    (** Print the signature of the module generated for [table], preceded
        by the comment of the table. *)
    method i_gen_table fmt table =
      Format.fprintf fmt "(** %s *)\n" table.ta_comment;
      Format.fprintf fmt "module %s :\n  sig\n"
        (String.capitalize (remove_table_prefix table.ta_name));
      self#i_gen_type fmt table;
      self#i_gen_create fmt table;
      self#i_gen_select fmt table;
      self#i_gen_insert fmt table;
      self#i_gen_update fmt table;
      self#i_gen_delete fmt table;
      self#i_gen_drop fmt table;
      Format.fprintf fmt "  end\n\n"

    (** Print the record type of the table; the text is the same as in the
        implementation. *)
    method i_gen_type = self#gen_type

    (** Print the signature of the [create] function of [table]. *)
    method i_gen_create fmt table =
      p fmt "  (** Create table %s.*)\n" table.ta_name;
      p fmt "  val create : db -> unit\n\n"

    (** Print the signature of the [select] function of [table]. *)
    method i_gen_select fmt table =
      p fmt "  (** Select records in table %s.*)\n" table.ta_name;
      p fmt "  val select : db -> ?distinct: bool -> \n";
      List.iter
        (fun c ->
          p fmt "    ?%s : (%s)%s ->\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then " option" else "")
        )
        table.ta_columns;
      p fmt "    unit -> t list\n\n"

    (** Print the signature of the [insert] function of [table]. *)
    method i_gen_insert fmt table =
      p fmt "  (** Insert a record in table %s.*)\n" table.ta_name;
      p fmt "  val insert : db -> \n";
      List.iter
        (fun c ->
          p fmt "    ?%s : (%s)%s ->\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then " option" else "")
        )
        table.ta_columns;
      p fmt "    unit -> unit\n\n"

    (** Print the signature of the [update] function of [table]: the
        arguments for the new values, then the [key_] arguments. *)
    method i_gen_update fmt table =
      p fmt "  (** Update records in table %s.*)\n" table.ta_name;
      p fmt "  val update : db -> \n";
      List.iter
        (fun c ->
          p fmt "    ?%s : (%s)%s ->\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then " option" else "")
        )
        table.ta_columns;
      List.iter
        (fun c ->
          p fmt "    ?key_%s : (%s)%s ->\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then " option" else "")
        )
        table.ta_columns;
      p fmt "    unit -> unit\n\n"

    (** Print the signature of the [delete] function of [table]. *)
    method i_gen_delete fmt table =
      p fmt "  (** Delete records from table %s.*)\n" table.ta_name;
      p fmt "  val delete : db -> \n";
      List.iter
        (fun c ->
          p fmt "    ?%s : (%s)%s ->\n"
            (String.lowercase c.col_name)
            c.col_type_ml
            (if c.col_nullable then " option" else "")
        )
        table.ta_columns;
      p fmt "    unit -> unit\n\n"

    (** Print the signature of the [drop] function of [table]. *)
    method i_gen_drop fmt table =
      p fmt "  (** Drop table %s.*)\n" table.ta_name;
      p fmt "  val drop : db -> unit\n\n"

  end




