(*
#directory "../common";;
#load "../common/trampoline.cmo";;
#load "../common/namedTerm.cmo";;
#use "09b_machine_alt.ml";;
*)

(* This version of the machine strengthens shape invariants of "09a_machine.ml". *)

open NamedTerm

(* A mutable memo cell.  In this file cells are created as [ref None]
   (rules 6 and 9 of [trans]), read in rules 12W and 12I, and filled with
   [Some nf] in rules 15W and 15I. *)
type 'a cache = 'a option ref

(* Finite maps keyed by identifiers; used below as the representation of
   environments.  Keys are ordered with the [compare] that is in scope at this
   point (presumably the polymorphic [Stdlib.compare] -- confirm that
   [NamedTerm] does not shadow it). *)
module Dict = Map.Make (struct
  type t = identifier

  let compare (a : t) (b : t) : int = compare a b
end)

(* Values manipulated by the machine ([wnf] presumably abbreviates
   weak normal form).

   - [Abs (x, body, env)]: a closure, built from a lambda and the environment
     it was met in (rule 2 of [trans]).
   - [Inert i]: an inert value, see [inert].
   - [Cache (cell, v)]: the value [v] paired with a memo cell in which its
     normal form is stored once computed (rules 12W, 14W and 15W). *)
type wnf =
  | Abs of identifier * n_term * env
  | Inert of inert
  | Cache of normal_n_term cache * wnf

(* Inert values: a variable, an inert value applied to an argument value
   (rule 8), or an inert value paired with a memo cell (rule 8C). *)
and inert =
  | V of identifier
  | IApp of inert * wnf
  | ICache of normal_n_term cache * inert

(* Environments map identifiers to values. *)
and env = wnf Dict.t

(* [env_lookup x e] returns the value bound to [x] in [e].  When [x] is not
   bound it is treated as a free variable: the result is the inert variable
   whose name is [x] followed by the suffix "_free".

   The function is not recursive, so it is declared with a plain [let]
   (a needless [rec] triggers warning 39, unused rec flag). *)
let env_lookup (x : identifier) (e : env) : wnf =
  match Dict.find_opt x e with
  | Some v -> v
  | None -> Inert (V (x ^ "_free"))

(* Stack used while terms are evaluated to values.

   - [Lapp (t1, e, k)]: the argument of an application is being evaluated;
     [t1] is the operator term still to be evaluated in [e] (rules 1 and 4).
   - [Rapp (v, k)]: the operator is being evaluated; [v] is the already
     evaluated argument (rules 4 to 8).
   - [Reify k3]: the value reaching this frame is to be normalised, with
     continuation [k3]. *)
type stack1 =
  | Lapp of n_term * env * stack1
  | Rapp of wnf * stack1
  | Reify of stack3

(* Stack expecting a neutral term.

   - [RAPP (t2, k)]: [t2] is the normalised argument waiting to be applied to
     the neutral term being produced (rules 16 and 17).
   - [NEUT k3]: the neutral term is injected into normal forms with [ENNeut]
     and passed to [k3].
   - [ICACHE (cell, k)]: the neutral term is stored in [cell] before being
     passed to [k] (rule 15I). *)
and stack2 =
  | RAPP of normal_n_term * stack2
  | NEUT of stack3
  | ICACHE of normal_n_term cache * stack2

(* Stack expecting a normal form.

   - [LAPP (i, k2)]: the argument of the inert application headed by [i] is
     being normalised; [i] is normalised next (rules 11 and 16).
   - [LAM (x, k3)]: the normal form is wrapped in a lambda binding [x]
     (rule 18).
   - [ID]: the empty stack; [SW (ID, _, _)] is the final configuration.
   - [CACHE (cell, k3)]: the normal form is stored in [cell] before being
     passed to [k3] (rule 15W). *)
and stack3 =
  | LAPP of inert * stack2
  | LAM of identifier * stack3
  | ID
  | CACHE of normal_n_term cache * stack3

(* Machine configurations.  The trailing [int] of every configuration is a
   counter used to build fresh variable names; it is incremented in rule 9 of
   [trans].

   - [E (t, e, k, m)]: evaluate term [t] in environment [e].
   - [CW (k, v, m)]: continue with value [v] on stack [k].
   - [CI (k, i, m)]: normalise the inert value [i] on stack [k].
   - [MW (hit, cell, k, v, m)]: memo lookup for [v]; [hit] is the content of
     [cell] read when the lookup started (rule 12W).
   - [MI (hit, cell, k, i, m)]: the same for an inert value (rule 12I).
   - [SI (k, t, m)]: pass the neutral term [t] to stack [k].
   - [SW (k, t, m)]: pass the normal form [t] to stack [k]. *)
type conf =
  | E of n_term * env * stack1 * int
  | CW of stack1 * wnf * int
  | CI of stack2 * inert * int
  | MW of normal_n_term option * normal_n_term cache * stack3 * wnf * int
  | MI of neutral_n_term option * normal_n_term cache * stack2 * inert * int
  | SI of stack2 * neutral_n_term * int
  | SW of stack3 * normal_n_term * int

(* [peel_neut t] extracts the neutral term wrapped by [ENNeut].  Any other
   normal form is considered impossible at the call site and fails with
   [assert false]. *)
let peel_neut (t : normal_n_term) : neutral_n_term =
  match t with
  | ENNeut neutral -> neutral
  | _ -> assert false

(* One transition of the machine.  Returns [None] for the final configuration
   [SW (ID, _, _)] and [Some c2] for every other configuration [c], where [c2]
   is the successor of [c].  The numbers in trailing comments label the
   individual rules.  Rules 15W and 15I write to a memo cell as a side
   effect. *)
let trans (c : conf) : conf option =
  (* [step] computes the successor of a non-final configuration. *)
  let step : conf -> conf = function
    | SW (ID, _, _) -> assert false (* final; filtered out below *)
    (* Evaluation of a term: the argument of an application goes first. *)
    | E (NApp (rator, rand), env, k, m) ->
      E (rand, env, Lapp (rator, env, k), m) (* 1 *)
    | E (NLam (x, body), env, k, m) -> CW (k, Abs (x, body, env), m) (* 2 *)
    | E (NVar x, env, k, m) -> CW (k, env_lookup x env, m) (* 3 *)
    (* A value meets an application frame.  The [Rapp] patterns overlap, so
       their relative order is significant. *)
    | CW (Lapp (rator, env, k), arg, m) ->
      E (rator, env, Rapp (arg, k), m) (* 4 *)
    | CW (Rapp ((Cache (_, _) as arg), k), Abs (x, body, env), m) ->
      E (body, Dict.add x arg env, k, m) (* 5 *)
    | CW (Rapp (arg, k), (Abs (_, _, _) as closure), m) ->
      (* The argument is not cached yet: wrap it and retry with rule 5. *)
      let cached_arg = Cache (ref None, arg) in
      CW (Rapp (cached_arg, k), closure, m) (* 6 *)
    | CW (Rapp (arg, k), Inert head, m) ->
      CW (k, Inert (IApp (head, arg)), m) (* 8 *)
    | CW (Rapp (arg, k), Cache (cell, Inert head), m) ->
      CW (k, Inert (IApp (ICache (cell, head), arg)), m) (* 8C *)
    | CW ((Rapp (_, _) as k), Cache (_, inner), m) -> CW (k, inner, m) (* 7 *)
    (* A value meets a [Reify] frame: start normalising it. *)
    | CW (Reify k3, Abs (x, body, env), m) ->
      (* The bound variable is renamed using the counter [m]. *)
      let fresh = x ^ "_" ^ string_of_int m in
      let bound = Cache (ref None, Inert (V fresh)) in
      E (body, Dict.add x bound env, Reify (LAM (fresh, k3)), m + 1) (* 9 *)
    | CW (Reify k3, Cache (cell, v), m) -> MW (!cell, cell, k3, v, m) (* 12W *)
    | CW (Reify k3, Inert i, m) -> CI (NEUT k3, i, m)
    (* Normalisation of inert values. *)
    | CI (k2, V x, m) -> SI (k2, ENVar x, m) (* 10 *)
    | CI (k2, IApp (head, arg), m) ->
      CW (Reify (LAPP (head, k2)), arg, m) (* 11 *)
    | CI (k2, ICache (cell, i), m) ->
      MI (Option.map peel_neut !cell, cell, k2, i, m) (* 12I *)
    (* Memo lookup: reuse a stored result, or compute it and remember to
       store it. *)
    | MW (Some nf, _, k3, _, m) -> SW (k3, nf, m) (* 13W *)
    | MW (None, cell, k3, v, m) -> CW (Reify (CACHE (cell, k3)), v, m) (* 14W *)
    | MI (Some neu, _, k2, _, m) -> SI (k2, neu, m) (* 13I *)
    | MI (None, cell, k2, i, m) -> CI (ICACHE (cell, k2), i, m) (* 14I *)
    (* A normal form is passed to a [stack3]. *)
    | SW (CACHE (cell, k3), nf, m) ->
      cell := Some nf;
      SW (k3, nf, m) (* 15W *)
    | SW (LAPP (head, k2), nf, m) -> CI (RAPP (nf, k2), head, m) (* 16 *)
    | SW (LAM (x, k3), nf, m) -> SW (k3, ENLam (x, nf), m) (* 18 *)
    (* A neutral term is passed to a [stack2]. *)
    | SI (ICACHE (cell, k2), neu, m) ->
      cell := Some (ENNeut neu);
      SI (k2, neu, m) (* 15I *)
    | SI (RAPP (arg_nf, k2), neu, m) -> SI (k2, ENApp (neu, arg_nf), m) (* 17 *)
    | SI (NEUT k3, neu, m) -> SW (k3, ENNeut neu, m)
  in
  match c with
  | SW (ID, _, _) -> None
  | _ -> Some (step c)
  

(* [load t] is the initial configuration for [t]: evaluate [t] in the empty
   environment, on a stack that normalises the result and then stops, with
   the fresh-name counter at 0. *)
let load (t : n_term) : conf =
  let empty_env = Dict.empty
  and initial_stack = Reify ID
  and first_index = 0 in
  E (t, empty_env, initial_stack, first_index)

(* [unload c] extracts the normal form from the final configuration
   [SW (ID, t, _)].  Any other configuration fails with [assert false]. *)
let unload (c : conf) : normal_n_term =
  match c with
  | SW (ID, result, _) -> result
  | _ -> assert false

open Trampoline

(* [nf_trampoline t] builds the trampolined normalisation of [t]: the initial
   configuration [load t] is handed to [opt_trampoline trans], and [unload] is
   mapped over the outcome.  [opt_trampoline] and [trampolined_map] are not
   defined in this file (presumably they come from [Trampoline]); the former
   presumably iterates [trans] until it returns [None] -- confirm there. *)
let nf_trampoline (t : n_term) : normal_n_term trampolined =
  load t |> opt_trampoline trans |> trampolined_map unload

(* [normal_form t] runs the machine on [t] via [pogo_stick] and converts the
   resulting [normal_n_term] back to an [n_term] with [n_term_of_normal].
   Both helpers are defined outside this file. *)
let normal_form (t : n_term) : n_term =
  nf_trampoline t |> pogo_stick |> n_term_of_normal

(* [steps_for_family ts n] applies [count_bounces] to the trampolined
   normalisation of the [n]-th term of the family [ts].  [count_bounces] is
   defined outside this file; presumably it counts machine transitions --
   confirm in [Trampoline]. *)
let steps_for_family (ts : int -> n_term) (n : int) : int =
  ts n |> nf_trampoline |> count_bounces

(* Applies [cbv_tests] to [normal_form] when this file is loaded and discards
   the result.  [cbv_tests] is not defined in this file (presumably it comes
   from one of the opened modules), so what it checks cannot be stated
   here. *)
let _ = cbv_tests normal_form

