(*
#directory "../common";;
#load "../common/trampoline.cmo";;
#load "../common/namedTerm.cmo";;
#use "11_potential.ml";;
*)

open NamedTerm

(* Heap addresses; fresh ones are handed out by [heap_alloc]. *)
type location = int
module Heap = Map.Make(struct type t = location   let compare = compare end)
(* A heap cell is allocated as [None] by [heap_alloc] and later filled with
   [Some t], the memoized read-back term, by rule 15 of [trans_extra]. *)
type heap = n_term option Heap.t

(* Allocates a new, empty ([None]) cell in [h] and returns its location
   together with the extended heap.  The new location is the current number
   of cells, which is fresh as long as the heap's keys are exactly
   0 .. cardinal-1 (the machine itself only ever adds cells). *)
let heap_alloc (h : heap) : location * heap =
  let fresh = Heap.cardinal h in
  let extended = Heap.add fresh None h in
  (fresh, extended)

(* Environments: finite maps from identifiers to values. *)
module Dict = Map.Make(struct type t = identifier let compare = compare end)

(* Machine values:
   - [Abs (x, t, e)]: closure of the abstraction binding [x] in body [t]
     over environment [e];
   - [V x]: a symbolic variable (produced for free variables by
     [env_lookup], and, wrapped in a [Cache], for the fresh names introduced
     when reducing under a lambda);
   - [IApp (i, v)]: a stuck application of the non-abstraction value [i]
     to the argument [v];
   - [Cache (l, v)]: value [v] whose read-back term is memoized in heap
     cell [l]. *)
type value = Abs of identifier * n_term * env
  | V of identifier | IApp of value * value
  | Cache of location * value
and env = value Dict.t

(* Value bound to [x] in [e].  An unbound (free) variable is not an error:
   it is represented by the symbolic value [V (x ^ "_free")].
   (The function is not recursive, so it is declared without [rec].) *)
let env_lookup (x : identifier) (e : env) : value =
  match Dict.find_opt x e with
  | Some v -> v
  | None   -> V (x ^ "_free")

(* Stack frames.
   Evaluation frames:
   - [Lapp (t1, e)]: the argument is being evaluated; afterwards evaluate the
     function part [t1] in [e] (rule 4);
   - [Rapp v]: argument value [v] waiting for the function value.
   Read-back frames:
   - [LAM xm]: wrap the read-back body in [NLam (xm, _)] (rule 18);
   - [LAPP i]: after reading back an argument, read back the function
     part [i] of an [IApp] (rule 16);
   - [RAPP t]: read-back argument term [t] waiting for the function term
     (rule 17);
   - [CACHE l]: store the read-back term in heap cell [l] (rule 15). *)
type frame = Lapp of n_term * env | Rapp of value
  | LAM of identifier | LAPP of value | RAPP of n_term | CACHE of location
type stack = frame list

(* Machine configurations.  In every mode the [int] is the counter used to
   generate fresh variable names (rule 9) and the [heap] holds the memoized
   read-back terms.
   - [E (t, e, s, m, h)]: evaluate term [t] in environment [e];
   - [C (s, v, m, h)]: continue with value [v];
   - [M (cell, l, s, v, m, h)]: memo lookup; [cell] is the content of heap
     cell [l] for the cached value [v] (rules 13/14);
   - [S (s, t, m, h)]: continue with read-back term [t]. *)
type conf =
  | E of n_term * env * stack                     * int * heap
  | C of stack * value                            * int * heap
  | M of n_term option * location * stack * value * int * heap
  | S of stack * n_term                           * int * heap

(* One step of the machine, tagged with the number (1-18) of the rule that
   fired.  Returns [None] on a final configuration [S([], t, _, _)].
   In every configuration [m] is the counter used to build fresh variable
   names (rule 9) and [h] is the heap of memoized read-back terms. *)
let trans_extra : conf -> (conf * int) option =
  function
  | S([], t, _, _) -> None
  | c -> Some (
  match c with
  (* The first case is unreachable (handled above); an S configuration with
     an evaluation frame (Lapp/Rapp) on top is treated as impossible. *)
  | S(           [], _, _, _) -> assert false
  | S(Lapp(_, _)::_, _, _, _) -> assert false
  | S(    Rapp _::_, _, _, _) -> assert false
  (* Evaluation (E): in an application the argument [t2] is evaluated before
     the function [t1]. *)
  | E(NApp(t1, t2), e, s1, m, h) -> (E(t2, e, Lapp(t1, e)::s1, m, h), 1)
  | E( NLam(x, t), e, s1, m, h)  -> (C(s1,  Abs(x, t, e), m, h), 2)
  (* Unbound variables become [V (x ^ "_free")], see [env_lookup]. *)
  | E(      NVar x, e, s1, m, h) -> (C(s1, env_lookup x e, m, h), 3)
  (* Continuation (C): the argument value [v] is ready; evaluate the
     function part. *)
  | C(Lapp(t1, e)::s1, v, m, h) -> (E(t1, e, Rapp v::s1, m, h), 4)
  (* Beta reduction; fires only once the argument is wrapped in a Cache. *)
  | C(Rapp (Cache(_, _) as v2)::s1, Abs(x, t, e), m, h)  -> (E(t, Dict.add x v2 e, s1, m, h), 5)
  (* Wrap a not-yet-cached argument in a fresh, empty heap cell so that its
     read-back is memoized (rules 12-15). *)
  | C(Rapp v2::s1,          (Abs(_, _, _) as xte), m, h) -> (let l, h' = heap_alloc h in
                                                            C(Rapp (Cache(l, v2))::s1, xte, m, h'), 6)
  (* A cached abstraction in function position is unwrapped. *)
  | C((Rapp _::_ as s1), Cache(_, (Abs(_, _, _) as v)), m, h) -> (C(s1, v, m, h), 7)
  (* Any other function value is stuck: build a neutral application. *)
  | C(Rapp v2::s1,                                   i, m, h) -> (C(s1, IApp(         i , v2) , m, h), 8)
  (* Strong reduction: when the stack top is not an Lapp/Rapp frame, go under
     the lambda, binding [x] to a cached fresh variable [x_m] and bumping
     the name counter [m]. *)
  | C(s2, Abs(x, t, e), m, h) -> (let xm = x ^ "_" ^ string_of_int m in
                                 let l, h' = heap_alloc h in
                                 E(t, Dict.add x (Cache (l, V xm)) e, LAM xm::s2, m+1, h'), 9)
  (* Read-back of values into terms. *)
  | C(s2,         V x, m, h) -> (S(s2, NVar x, m, h), 10)
  (* Read back the argument [v] first; [LAPP i] remembers the function
     part. *)
  | C(s2,  IApp(i, v), m, h) -> (C(LAPP i::s2, v, m, h), 11)
  (* Cached value: look up its heap cell (memo mode M). *)
  | C(s2, Cache(l, v), m, h) -> (M(Heap.find l h, l, s2, v, m, h), 12)
  (* Memo hit: reuse the stored term. *)
  | M(Some y, _, s2, _, m, h) -> (S(         s2, y, m, h), 13)
  (* Memo miss: read back [v], then store the result at [l] (rule 15). *)
  | M(  None, l, s2, v, m, h) -> (C(CACHE l::s2, v, m, h), 14)
  (* Read-back continuations (S): rebuild the term from its parts. *)
  | S(CACHE l::s2, t, m, h) -> (S(s2, t, m, Heap.add l (Some t) h), 15)
  | S( LAPP i::s2, t, m, h) -> (C(RAPP t::s2, i, m, h), 16)
  | S(RAPP t2::s2, t, m, h) -> (S(s2, NApp(t, t2), m, h), 17)
  | S( LAM xm::s2, t, m, h) -> (S(s2, NLam(xm, t), m, h)), 18)

open Trampoline

(* One machine step without the rule number reported by [trans_extra].
   The pair is projected with a local lambda rather than [Pervasives.fst],
   which is deprecated since OCaml 4.08 and removed in OCaml 5.0. *)
let trans (c : conf) : conf option =
  opt_map (fun (next, _rule) -> next) (trans_extra c)

(* Initial configuration: evaluate [t] in the empty environment, with an
   empty stack, fresh-name counter 0 and an empty heap. *)
let load (t : n_term) : conf =
  let initial_env = Dict.empty and initial_heap = Heap.empty in
  E (t, initial_env, [], 0, initial_heap)

(* Extracts the term from a final configuration [S([], t, _, _)]; reaching
   this with any other configuration is a programming error. *)
let unload (c : conf) : n_term =
  match c with
  | S ([], nf, _, _) -> nf
  | E _ | C _ | M _ | S _ -> assert false

(* Trampolined run of the machine on [t]: start from [load t], step with
   [trans] (via Trampoline's [opt_trampoline]) and map the final
   configuration through [unload]. *)
let nf_trampoline (t:n_term) : n_term trampolined =
  load t
  |> opt_trampoline trans
  |> trampolined_map unload

(* Normal form of [t], obtained by running [nf_trampoline t] with
   Trampoline's [pogo_stick]. *)
let normal_form (t : n_term) : n_term =
  nf_trampoline t |> pogo_stick

(* Bounce count (as reported by Trampoline's [count_bounces]) of the
   trampolined run on the [n]-th member [ts n] of a term family. *)
let steps_for_family (ts:int -> n_term) (n:int) : int =
  ts n
  |> nf_trampoline
  |> count_bounces

(* Debugging helper: applies [opt_try_step trans] [i] times (via
   Trampoline's [iter]) starting from [load t]; presumably this yields the
   configuration after [i] steps, or the final one if reached earlier. *)
let ith_conf_of (i:int) (t:n_term) : conf =
  let step = opt_try_step trans in
  load t |> iter i step

(* Regression checks: step counts of the machine on several term families,
   followed by the CBV normal-form tests.  The resulting bool list is only
   displayed (e.g. when the file is #use'd in the toplevel); a [false]
   entry does not abort anything. *)
let _ =
  let steps_match family ns expected =
    List.map (steps_for_family family) ns = expected
  in
  [
    steps_match test1    (ints 0 5 @ [9]) [23; 38; 69; 124; 227; 6214];
    steps_match explode  (ints 0 5 @ [9]) [19; 34; 53; 72; 91; 186];
    steps_match explode2 (ints 0 9) [19; 37; 59; 81; 103; 125; 147; 169; 191];
    steps_match explode3 (ints 0 9) [19; 48; 81; 114; 147; 180; 213; 246; 279];
    steps_match explode4 (ints 0 9) [23; 52; 85; 118; 151; 184; 217; 250; 283];
  ] @ cbv_tests normal_form

(* 1 for [true], 0 for [false]. *)
let int_of_bool (b:bool) =
  match b with
  | true -> 1
  | false -> 0

(* Potential of a term: 6 per application node, 4 per abstraction node and
   4 per variable occurrence. *)
let psi_t : n_term -> int =
  let rec go acc = function
    | NVar _ -> acc + 4
    | NLam (_, body) -> go (acc + 4) body
    | NApp (fn, arg) -> go (go (acc + 6) fn) arg
  in
  go 0

(* Potential of a value.  A cached value counts as a flat 3: the potential
   of its payload is accounted for separately through the heap part
   ([psi_h]). *)
let rec psi_v (v : value) : int =
  match v with
  | V _ -> 1
  | Cache (_, _) -> 3
  | Abs (_, body, _) -> psi_t body + 3
  | IApp (fn, arg) -> psi_v fn + psi_v arg + 3

(* Potential contributed by a single stack frame. *)
let psi_f : frame -> int = function
  | RAPP _ | LAM _ | CACHE _ -> 1
  | LAPP fn -> psi_v fn + 2
  | Rapp arg -> psi_v arg + 4
  | Lapp (t, _) -> psi_t t + 5

(* Potential of a stack: the sum of the potentials of its frames. *)
let psi_s (s : stack) : int =
  List.fold_left (fun total frame -> total + psi_f frame) 0 s

(* Values found inside [Cache (l, v)] nodes of a configuration, keyed by
   their heap location [l] (built by [add_cached] and friends). *)
type val_heap = value Heap.t

(* Records that heap location [l] caches value [v].  If [l] is already
   recorded, the stored value must be the very same, physically equal
   ([==]) value: the payload of a Cache cell is shared rather than copied
   by the machine, so a mismatch signals an inconsistent configuration.
   The function is not recursive, so it is declared without [rec].
   @raise Assert_failure on such a mismatch. *)
let add_cached (l : location) (v : value) (vs : val_heap) : val_heap =
  match Heap.find_opt l vs with
  | Some v' -> assert (v == v'); vs
  | None    -> Heap.add l v vs

(* Adds to [vs] every [Cache (l, v)] reachable from a value, including those
   nested in closure environments and in cached payloads. *)
let rec add_cached_of_v (v : value) (vs : val_heap) : val_heap =
  match v with
  | V _ -> vs
  | Abs (_, _, env) -> add_cached_of_e env vs
  | IApp (fn, arg) ->
    (* the argument is traversed before the function part *)
    let vs = add_cached_of_v arg vs in
    add_cached_of_v fn vs
  | Cache (loc, payload) ->
    let vs = add_cached loc payload vs in
    add_cached_of_v payload vs

(* Same as [add_cached_of_v], for every value bound in an environment. *)
and add_cached_of_e (e : env) (vs : val_heap) : val_heap =
  Dict.fold (fun _ bound acc -> add_cached_of_v bound acc) e vs

(* Adds to [vs] the cached values reachable from one stack frame.  Only
   frames holding an environment or a value can reach a Cache. *)
let add_cached_of_f (f : frame) (vs : val_heap) : val_heap =
  match f with
  | Lapp (_, env) -> add_cached_of_e env vs
  | Rapp held | LAPP held -> add_cached_of_v held vs
  | LAM _ | RAPP _ | CACHE _ -> vs

(* Adds to [vs] the cached values reachable from every frame of [s].
   Frames are processed from the last list element to the first, the same
   order as [List.fold_right]. Folding left over the reversed list keeps
   the traversal tail-recursive, so deep stacks cannot overflow the call
   stack ([List.fold_right] is not tail-recursive). *)
let add_cached_of_s (s : stack) (vs : val_heap) : val_heap =
  List.fold_left (fun acc frame -> add_cached_of_f frame acc) vs (List.rev s)

(* Removes from [h] every location that has a [CACHE] frame on [s], i.e. the
   cells whose read-back result is still pending (pushed by rule 14). *)
let remove_active_of_s (s : stack) (h : heap) : heap =
  let drop_active acc = function
    | CACHE loc -> Heap.remove loc acc
    | Lapp _ | Rapp _ | LAM _ | LAPP _ | RAPP _ -> acc
  in
  List.fold_left drop_active h s

(* Heap part of the potential.  Every cell of [h] that is still empty
   ([None]) and has no [CACHE] frame on [s] contributes [psi_v] of the value
   recorded for its location in [vs]; [opt_else 0] supplies the fallback
   when [vs] has no entry.  Filled cells contribute nothing. *)
let psi_h (s : stack) (h : heap) (vs : val_heap) : int =
  let pending = remove_active_of_s s h in
  let cached_potential loc =
    Heap.find_opt loc vs |> opt_map psi_v |> opt_else 0
  in
  Heap.fold
    (fun loc cell total ->
       match cell with
       | None -> total + cached_potential loc
       | Some _ -> total)
    pending 0

(* 1 if the next step from [c] is the beta rule (rule 5 of [trans_extra]),
   0 otherwise (including when [c] is final). *)
let fits_beta_rule (c : conf) : int =
  let rule =
    match trans_extra c with
    | Some (_, r) -> r
    | None -> 0
  in
  if rule = 5 then 1 else 0

(* Potential of a configuration: potentials of its term/stack/value parts
   plus the heap part [psi_h], computed over the cached values reachable
   from the configuration's components and its stack.  An M configuration
   adds a constant 2; a C configuration about to fire the beta rule is
   discounted by 9. *)
let psi_k (c:conf) : int =
  (* [roots] holds the cached values found outside the stack; the stack's
     own cached values are added here. *)
  let heap_part s h roots = psi_h s h (add_cached_of_s s roots) in
  match c with
  | E (t, e, s, _, h) ->
    psi_t t + psi_s s + heap_part s h (add_cached_of_e e Heap.empty)
  | C (s, v, _, h) ->
    let beta_discount = 9 * fits_beta_rule c in
    psi_s s + psi_v v - beta_discount
    + heap_part s h (add_cached_of_v v Heap.empty)
  | S (s, _, _, h) ->
    psi_s s + heap_part s h Heap.empty
  | M (_, l, s, v, _, h) ->
    let roots = Heap.empty |> add_cached l v |> add_cached_of_v v in
    2 + psi_s s + heap_part s h roots

(* Numeric tag of a configuration's mode (third CSV column of
   [write_report]): E = 0, C = 10, M = 20, S = 30. *)
let mode (c : conf) : int =
  match c with
  | E _ -> 0
  | C _ -> 10
  | M _ -> 20
  | S _ -> 30

(* [write_report title t] runs the machine from [load t] until it stops and
   writes one CSV row per visited configuration to [title ^ ".csv"]:
   "index; potential; mode; rule", where the final configuration's row has
   no rule column.  After every step it asserts that the potential [psi_k]
   strictly decreases, except for rule 7, which is exempt.
   The channel is closed even when that assertion (or [trans_extra]) raises,
   so the rows written so far reach the file before the exception
   propagates.
   @raise Assert_failure if the potential does not decrease. *)
let write_report (title : string) (t : n_term) : unit =
  let oc = open_out (title ^ ".csv") in
  (* [p] is [psi_k c]; it is threaded through so each potential is computed
     once rather than twice per configuration. *)
  let rec aux (i : int) (c : conf) (p : int) : unit =
    Printf.fprintf oc "%d; %d; %d" i p (mode c);
    match trans_extra c with
    | None -> output_char oc '\n'
    | Some (c', r) ->
      let p' = psi_k c' in
      assert (p > p' || r = 7);
      Printf.fprintf oc "; %d\n" r;
      aux (i + 1) c' p'
  in
  let c0 = load t in
  match aux 0 c0 (psi_k c0) with
  | () -> close_out oc
  | exception e -> close_out_noerr oc; raise e

(* Write the potential reports for a few sample terms.  [let () =] (rather
   than [let _ =]) makes the compiler check that each call returns unit. *)
let () = write_report "explode4_6" (explode4 6)
let () = write_report "implode2_6" (implode2 6)
let () = write_report "test1_6"    (test1 6)

