(*
#directory "../common";;
#load "../common/trampoline.cmo";;
#load "../common/namedTerm.cmo";;
#use "10_heap.ml";;
*)

open NamedTerm

type location = int

(* Heap cells are indexed by integer locations.  The comparison is annotated
   at [location] (= [int]) so the compiler uses a specialised integer compare
   instead of the generic structural [compare]; the ordering is identical. *)
module Heap = Map.Make (struct
  type t = location
  let compare (a : location) (b : location) = compare a b
end)

(* A heap cell is [None] until the readback of the value it memoises has
   finished, after which it holds the resulting term. *)
type heap = n_term option Heap.t

(** [heap_alloc h] allocates a new, empty ([None]) cell and returns its
    location together with the extended heap.

    The new location is one past the largest bound location, or [0] for an
    empty heap, so it is never already bound in [h].  When locations are only
    ever created by this function they are [0, 1, 2, ...] in allocation order.

    Using [Heap.max_binding_opt] (O(log n)) instead of [Heap.cardinal] (which
    walks the whole map) keeps allocation from being linear in the heap size.
    It also cannot collide with an existing cell if keys are ever
    non-contiguous. *)
let heap_alloc (h : heap) : location * heap =
  let l =
    match Heap.max_binding_opt h with
    | None -> 0
    | Some (max_l, _) -> max_l + 1
  in
  (l, Heap.add l None h)

(* Environments are maps keyed by identifiers.  [identifier] is [string]
   (identifiers are built with [^]), so [String.compare] gives the same
   ordering as polymorphic [compare] without the generic structural walk. *)
module Dict = Map.Make (struct
  type t = identifier
  let compare = String.compare
end)

(* Values produced by evaluation.  [value] and [env] are mutually recursive:
   closures capture environments, and environments bind identifiers to
   values. *)
type value =
  | Abs of identifier * n_term * env
      (** Closure: bound identifier, body, captured environment. *)
  | V of identifier
      (** Irreducible variable: a free variable, or a fresh binder name
          introduced while reading back under a lambda. *)
  | IApp of value * value
      (** Stuck application: its head is not an abstraction. *)
  | Cache of location * value
      (** A value paired with the heap cell that memoises its readback. *)

and env = value Dict.t

(** [env_lookup x e] returns the value bound to [x] in [e].

    An identifier not bound in [e] is treated as a free variable and returned
    as the variable [x ^ "_free"].

    The function is not recursive, so it is bound with plain [let]; the
    previous [let rec] triggered unused-rec warning 39. *)
let env_lookup (x : identifier) (e : env) : value =
  match Dict.find_opt x e with
  | Some v -> v
  | None -> V (x ^ "_free")

(* Continuation frames.  [Lapp] and [Rapp] are pushed while evaluating; the
   upper-case frames are pushed while reading values back into terms. *)
type frame =
  | Lapp of n_term * env
      (** Operator still to be evaluated in this environment (the argument
          of an application is evaluated first). *)
  | Rapp of value
      (** Evaluated argument waiting for its operator's value. *)
  | LAM of identifier
      (** Wrap the readback result in a lambda binding this name. *)
  | LAPP of value
      (** Head of a stuck application, read back after its argument. *)
  | RAPP of n_term
      (** Read-back argument, applied to the head's readback result. *)
  | CACHE of location
      (** Record the readback result in this heap cell. *)

type stack = frame list

(* Machine configurations.  Each one carries the fresh-name counter (the
   [int], used to build binder names during readback) and the memo heap. *)
type conf =
  | E of n_term * env * stack * int * heap
      (** Evaluate a term in an environment. *)
  | C of stack * value * int * heap
      (** Continue with a computed value. *)
  | M of n_term option * location * stack * value * int * heap
      (** Inspect the memo cell of a cached value: [Some] is a hit,
          [None] a miss. *)
  | S of stack * n_term * int * heap
      (** Continue readback with a finished term. *)

(** One step of the machine.

    Returns [None] exactly on the final configuration [S ([], _, _, _)]
    (readback finished with an empty stack); otherwise [Some] successor.
    Raises [Assert_failure] if a readback result meets an evaluation frame.

    The arms of the inner match are order-sensitive, in particular the [C]
    cases with an [Rapp] frame, so they must not be reordered.  The outer
    [t] binding is now [_]; it was unused (warning 27). *)
let trans : conf -> conf option =
  function
  | S([], _, _, _) -> None
  | c -> Some (
  match c with
  (* Not reached: the final configuration is handled above, and these arms
     only make the inner match exhaustive. *)
  | S(           [], _, _, _) -> assert false
  | S(Lapp(_, _)::_, _, _, _) -> assert false
  | S(    Rapp _::_, _, _, _) -> assert false
  (* Evaluation: the argument [t2] is evaluated before the operator [t1]. *)
  | E(NApp(t1, t2), e, s1, m, h) -> E(t2, e, Lapp(t1, e)::s1, m, h)
  | E( NLam(x, t), e, s1, m, h) -> C(s1,  Abs(x, t, e), m, h)
  | E(      NVar x, e, s1, m, h) -> C(s1, env_lookup x e, m, h)
  (* Argument value computed: now evaluate the operator. *)
  | C(Lapp(t1, e)::s1, v, m, h) -> E(t1, e, Rapp v::s1, m, h)
  (* Beta step: an already-cached argument is bound directly; any other
     argument is first wrapped in a freshly allocated cache cell. *)
  | C(Rapp (Cache(_, _) as v2)::s1, Abs(x, t, e), m, h) -> E(t, Dict.add x v2 e, s1, m, h)
  | C(Rapp v2::s1,          (Abs(_, _, _) as xte), m, h) -> let l, h' = heap_alloc h in
                                                            C(Rapp (Cache(l, v2))::s1, xte, m, h')
  (* A cached abstraction in operator position is unwrapped. *)
  | C((Rapp _::_ as s1), Cache(_, (Abs(_, _, _) as v)), m, h) -> C(s1, v, m, h)
  (* Any other operator: the application is stuck. *)
  | C(Rapp v2::s1,                                   i, m, h) -> C(s1, IApp(         i , v2) , m, h)
  (* Readback under a lambda: bind [x] to a cached fresh variable [x_m] and
     keep evaluating the body beneath a [LAM] frame. *)
  | C(s2, Abs(x, t, e), m, h) -> let xm = x ^ "_" ^ string_of_int m in
                                 let l, h' = heap_alloc h in
                                 E(t, Dict.add x (Cache (l, V xm)) e, LAM xm::s2, m+1, h')
  | C(s2,         V x, m, h) -> S(s2, NVar x, m, h)
  (* Stuck application: read back the argument first, the head later. *)
  | C(s2,  IApp(i, v), m, h) -> C(LAPP i::s2, v, m, h)
  (* Cached value: consult its memo cell. *)
  | C(s2, Cache(l, v), m, h) -> M(Heap.find l h, l, s2, v, m, h)
  (* Memo hit reuses the stored term; a miss reads back [v] and records the
     result via a [CACHE] frame. *)
  | M(Some y, _, s2, _, m, h) -> S(         s2, y, m, h)
  | M(  None, l, s2, v, m, h) -> C(CACHE l::s2, v, m, h)
  (* Readback continuations. *)
  | S(CACHE l::s2, t, m, h) -> S(s2, t, m, Heap.add l (Some t) h)
  | S( LAPP i::s2, t, m, h) -> C(RAPP t::s2, i, m, h)
  | S(RAPP t2::s2, t, m, h) -> S(s2, NApp(t, t2), m, h)
  | S( LAM xm::s2, t, m, h) -> S(s2, NLam(xm, t), m, h))

(* Initial configuration: evaluate [t] in the empty environment with an empty
   stack, a fresh-name counter of zero and an empty heap. *)
let load (t : n_term) : conf =
  let initial_counter = 0 in
  E (t, Dict.empty, [], initial_counter, Heap.empty)

(* Extract the normal form from a final configuration.  Any non-final
   configuration is a caller error and fails with [Assert_failure]. *)
let unload (c : conf) : n_term =
  match c with
  | S ([], result, _, _) -> result
  | E _ | C _ | M _ | S (_ :: _, _, _, _) -> assert false

open Trampoline

(* Trampolined computation of the normal form of [t]: load it, step with
   [trans] until no transition applies, then unload the final term. *)
let nf_trampoline (t : n_term) : n_term trampolined =
  load t
  |> opt_trampoline trans
  |> trampolined_map unload

(* Run the machine on [t] to completion and return its normal form. *)
let normal_form (t : n_term) : n_term =
  t |> nf_trampoline |> pogo_stick

(* Number of machine steps (trampoline bounces) needed to normalise the
   [n]-th member of the term family [ts]. *)
let steps_for_family (ts : int -> n_term) (n : int) : int =
  n |> ts |> nf_trampoline |> count_bounces

(* Regression checks: step counts for several term families, followed by the
   shared call-by-value tests.  The result is a [bool list] (every entry should
   be [true]); it is left unbound so the toplevel prints it. *)
let _ =
  let matches family inputs expected =
    List.map (steps_for_family family) inputs = expected
  in
  let sample = ints 0 5 @ [ 9 ] in
  let range = ints 0 9 in
  [ matches test1 sample [ 23; 38; 69; 124; 227; 6214 ];
    matches explode sample [ 19; 34; 53; 72; 91; 186 ];
    matches explode2 range [ 19; 37; 59; 81; 103; 125; 147; 169; 191 ];
    matches explode3 range [ 19; 48; 81; 114; 147; 180; 213; 246; 279 ];
    matches explode4 range [ 23; 52; 85; 118; 151; 184; 217; 250; 283 ] ]
  @ cbv_tests normal_form

