(*
#directory "../common";;
#load "../common/trampoline.cmo";;
#load "../common/namedTerm.cmo";;
#use "09a_machine.ml";;
*)

(* The functional notation is changed to a term-rewriting one.

   There are four variants of configurations:
 *     "E(t,e,s,m)"   corresponds to a call of "eval m t e s",
 *     "C(s,v,m)"     corresponds to a call of "continue1 m s v",
 *     "M(t,c,s,v,m)" corresponds to a call of "cached_reify t c m v s",
 *                    where t is the current contents of the cache c,
 *     "S(s,t,m)"     corresponds to a call of "continue2 m s t".
   In every configuration m is the counter used to generate fresh names. *)


open NamedTerm

(** A mutable memo cell: [None] while empty, [Some x] once filled. *)
type 'elt cache = 'elt option ref

(* Finite maps keyed by identifiers, used for environments.  [identifier]
   is [string] here (env_lookup concatenates identifiers with [^]), so the
   monomorphic [String.compare] is used instead of polymorphic [compare];
   it induces exactly the same ordering but avoids the generic comparison. *)
module Dict = Map.Make (struct
  type t = identifier
  let compare = String.compare
end)

(** Runtime values of the machine. *)
type value =
  | Abs of identifier * n_term * env  (** closure: parameter, body, captured environment *)
  | Inert of inert                    (** a stuck term: variable or inert application *)
  | Cache of n_term cache * value     (** a value paired with a cell memoising its reified term *)

(** Inert (stuck) terms. *)
and inert =
  | V of identifier                   (** variable *)
  | IApp of inert * value             (** inert head applied to a value *)
  | ICache of n_term cache * inert    (** an inert term paired with a memo cell *)

(** Environments map identifiers to values. *)
and env = value Dict.t

(** [env_lookup x e] returns the value bound to [x] in [e].
    An unbound identifier does not raise: it is treated as a free variable
    and returned as the inert variable [x ^ "_free"]. *)
let env_lookup (x : identifier) (e : env) : value =
  match Dict.find_opt x e with
  | Some v -> v
  | None -> Inert (V (x ^ "_free"))

(** Continuation frames.  [Lapp] and [Rapp] are used while evaluating an
    application; [LAM], [LAPP], [RAPP] and [CACHE] while reifying a value
    back into a term. *)
type frame =
  | Lapp of n_term * env   (** operator still to be evaluated, with its environment *)
  | Rapp of value          (** evaluated argument waiting for the operator *)
  | LAM of identifier      (** wrap the reified body in [NLam] *)
  | LAPP of inert          (** inert head to reify once its argument is reified *)
  | RAPP of n_term         (** reified argument waiting for the reified head *)
  | CACHE of n_term cache  (** store the reified term into this cell *)

(** The machine's control stack, top frame first. *)
type stack = frame list

(** Machine configurations.  The trailing [int] in every variant is the
    counter used to generate fresh variable names. *)
type conf =
  | E of n_term * env * stack * int
      (** evaluate a term in an environment *)
  | C of stack * value * int
      (** continue with a computed value *)
  | M of n_term option * n_term cache * stack * value * int
      (** consult a cache: its current contents, the cell itself, the
          continuation, and the value to reify on a miss *)
  | S of stack * n_term * int
      (** continue with a reified term *)

(** One transition of the machine.

    Returns [None] on the final configuration [S([], t, _)], where [t] is
    the normal form, and [Some c'] with the successor configuration
    otherwise.

    Rules are tried top to bottom and several patterns overlap, so their
    order is significant:
    - rules 4-8C (a [Lapp]/[Rapp] frame on top) come before rules 9-12W,
      which accept any stack;
    - rule 5 comes before rule 6, which would otherwise re-wrap an already
      cached argument forever;
    - rule 8C comes before rule 7, which also matches [Cache(_, Inert _)].

    Side effect: rule 15 writes the reified term into cache cell [c], so a
    later reification of a value sharing that cell is answered by rule 13
    without being recomputed. *)
let trans : conf -> conf option =
  function
  | S([], _, _) -> None
  | c -> Some (
  match c with
  (* The final configuration is handled above; listed for exhaustiveness. *)
  | S(           [], _, _) -> assert false
  (* Assumed unreachable: Lapp/Rapp frames are always consumed by rules 4-8C
     in C configurations before reification yields an S configuration. *)
  | S(Lapp(_, _)::_, _, _) -> assert false
  | S(    Rapp _::_, _, _) -> assert false
  (* Evaluation: in an application the argument t2 is evaluated first. *)
  | E(NApp(t1, t2), e, s1, m) -> E(t2, e, Lapp(t1, e)::s1, m) (* 1 *)
  | E( NLam(x, t'), e, s1, m) -> C(s1,  Abs(x, t', e), m)     (* 2 *)
  | E(      NVar x, e, s1, m) -> C(s1, env_lookup x e, m)     (* 3 *)
  (* Application: an argument is wrapped in a fresh cache (rule 6) before it
     is bound (rule 5), so every occurrence of the parameter shares one cell. *)
  | C(Lapp(t1, e)::s1, v, m) -> E(t1, e, Rapp v::s1, m)       (* 4 *)
  | C(Rapp (Cache(_, _) as v2)::s1, Abs(x, t', e), m) -> E(t', Dict.add x v2 e, s1, m) (* 5 *)
  | C(Rapp v2::s1,          (Abs(_, _, _) as xte), m) -> C(Rapp (Cache(ref None, v2))::s1, xte, m) (* 6 *)
  | C(Rapp v2::s1,           Inert i, m) -> C(s1, Inert(IApp(           i, v2)), m) (* 8 *)
  | C(Rapp v2::s1, Cache(c, Inert i), m) -> C(s1, Inert(IApp(ICache(c, i), v2)), m) (* 8C *)
  | C((Rapp _::_ as s1), Cache(_, v), m) -> C(s1, v, m) (* 7 *)
  (* Reification: a closure is read back under a fresh name x_m. *)
  | C(s2, Abs(x, t', e), m) -> let xm = x ^ "_" ^ string_of_int m in
                               E(t', Dict.add x (Cache (ref None, Inert (V xm))) e, LAM xm::s2, m+1) (* 9 *)
  | C(s2, Inert(         V x), m) -> S(s2, NVar x, m)          (* 10 *)
  | C(s2, Inert(  IApp(i, v)), m) -> C(LAPP i::s2, v, m)       (* 11 *)
  (* Cached values go through M, which consults the cell first. *)
  | C(s2, Inert(ICache(c, i)), m) -> M(!c, c, s2, Inert i, m)  (* 12I *)
  | C(s2,        Cache(c, v) , m) -> M(!c, c, s2, v, m)        (* 12W *)
  | M(Some y, _, s2, _, m) -> S(         s2, y, m)             (* 13 *)
  | M(  None, c, s2, v, m) -> C(CACHE c::s2, v, m)             (* 14 *)
  (* Rebuilding the reified term from the reification frames. *)
  | S(CACHE c::s2, t, m) -> c := Some t; S(s2, t, m)           (* 15 *)
  | S( LAPP i::s2, t, m) -> C(RAPP t::s2, Inert i, m)          (* 16 *)
  | S(RAPP t2::s2, t, m) -> S(s2, NApp(t, t2), m)              (* 17 *)
  | S( LAM xm::s2, t, m) -> S(s2, NLam(xm, t), m))             (* 18 *)

(** Initial configuration: evaluate [t] in the empty environment, with an
    empty stack and the fresh-name counter at 0. *)
let load : n_term -> conf = fun term -> E (term, Dict.empty, [], 0)

(** Extract the normal form from the final configuration [S([], t, _)].
    Any other configuration fails the assertion. *)
let unload (final : conf) : n_term =
  match final with
  | S ([], nf, _) -> nf
  | E _ | C _ | M _ | S (_ :: _, _, _) -> assert false

open Trampoline

(** Trampolined run of the machine on [t]: step [trans] from [load t] and
    map [unload] over the final configuration (presumably one bounce per
    [trans] step — see [Trampoline]). *)
let nf_trampoline (t : n_term) : n_term trampolined =
  load t |> opt_trampoline trans |> trampolined_map unload

(** Normal form of [t], obtained by running its trampoline to completion. *)
let normal_form (t : n_term) : n_term = nf_trampoline t |> pogo_stick

(** Number of trampoline bounces needed to normalize the [n]-th member
    [ts n] of a family of terms. *)
let steps_for_family (ts : int -> n_term) (n : int) : int =
  ts n |> nf_trampoline |> count_bounces

(* Regression checks: step counts for several term families, followed by the
   shared call-by-value tests.  The resulting bool list is echoed by the
   toplevel when this file is #use'd. *)
let _ =
  let steps family inputs = List.map (steps_for_family family) inputs in
  [ steps test1 (ints 0 5 @ [9]) = [23; 38; 69; 124; 227; 6214];
    steps explode (ints 0 5 @ [9]) = [19; 34; 53; 72; 91; 186];
    steps explode2 (ints 0 9) = [19; 37; 59; 81; 103; 125; 147; 169; 191];
    steps explode3 (ints 0 9) = [19; 48; 81; 114; 147; 180; 213; 246; 279];
    steps explode4 (ints 0 9) = [23; 52; 85; 118; 151; 184; 217; 250; 283] ]
  @ cbv_tests normal_form

