(*
#directory "../common";;
#load "../common/namedTerm.cmo";;
#use "06b_cps.ml";;
*)

(* Here we transform the evaluator into continuation-passing style
   (CPS). This is the second part of the transformation: we rewrite
   all transformed functions so that every call to them is a tail
   call. *)


open NamedTerm

(** A mutable memo cell: [None] until it is filled, then [Some] of the
    stored result. *)
type 'a cache = 'a option ref

(** Environments map identifiers to values. [identifier] is [string] in this
    file (identifiers are concatenated with [^]), so keys are ordered with the
    monomorphic [String.compare] rather than the polymorphic [compare]; the
    resulting ordering is the same. *)
module Dict = Map.Make (struct
  type t = identifier
  let compare = String.compare
end)

(** Semantic values produced by the evaluator. Constructor order is kept
    unchanged on purpose. *)
type value =
  | Abs of identifier * n_term * env
      (** Closure: binder name, body, captured environment. *)
  | Inert of inert
      (** A stuck term: a head variable applied to values. *)
  | Cache of n_term cache * value
      (** A value paired with a memo cell for its read-back term. *)

and inert =
  | V of identifier
  | IApp of inert * value
  | ICache of n_term cache * inert
      (** An inert head paired with a memo cell for its rendering. *)

and env = value Dict.t

(** [gensym ()] hands out a fresh integer on every call, counting up from 0.
    The counter lives in this closure and is not reachable from outside. *)
let gensym : unit -> int =
  let counter = ref 0 in
  fun () ->
    let fresh = !counter in
    incr counter;
    fresh

(** [env_lookup x e] returns the value bound to [x] in [e]. An identifier
    with no binding is treated as a free variable: it is returned as the
    inert variable named [x ^ "_free"].
    (Not recursive, so no [rec]: the unused [rec] triggered warning 39.) *)
let env_lookup (x : identifier) (e : env) : value =
  match Dict.find_opt x e with
  | Some v -> v
  | None -> Inert (V (x ^ "_free"))

(** [mount_cache v] attaches a fresh, empty cache to [v]. If [v] already
    carries a cache it is returned as-is, so caches are never nested. *)
let mount_cache : value -> value = function
  | Cache _ as already_cached -> already_cached
  | plain -> Cache (ref None, plain)


(* Every call to a transformed function below (eval, apply_abs, reify,
   cached_reify, apply_value, render_inert) is a tail call; the work that
   remains is carried by the continuation [k]. apply_inert only builds a
   constructor and stays in direct style.

   Inside this [let rec] the functions are monomorphic: the answer type ['a]
   of eval/apply_abs/apply_value is unified with [n_term] by the call to
   apply_abs from reify.

   Call order matters: reify draws fresh names from [gensym] and
   cached_reify mutates caches, so the order in which subterms are visited
   determines the generated names. *)

(** [eval t e k] evaluates [t] in environment [e] and passes the resulting
    value to [k]. In an application, the argument is evaluated before the
    operator. *)
let rec eval (t : n_term) (e : env) (k : value -> 'a) : 'a =
  match t with
  | NVar x        -> k @@ env_lookup x e
  | NLam (x, t')  -> k @@ Abs (x, t', e)
  | NApp (t1, t2) ->
    eval t2 e (fun v2 ->
    eval t1 e (fun v1 ->
    apply_value v1 v2 k))

(** [apply_abs x t e v k] evaluates the body [t] of a closure over [e],
    with [x] bound to [v] (wrapped by [mount_cache]), and passes the result
    to [k]. *)
and apply_abs (x : identifier) (t : n_term) (e : env) (v : value)
    (k : value -> 'a) : 'a =
  eval t (Dict.add x (mount_cache v) e) k

(** [reify v k] reads [v] back as a term and passes it to [k]. *)
and reify (v : value) (k : n_term -> n_term) : n_term =
  match v with
  | Abs (x, t', e) ->
    (* Rename the binder with a fresh name, then normalize under it by
       applying the closure to that name as an inert variable. *)
    let xm = x ^ "_" ^ string_of_int (gensym ()) in
    apply_abs x t' e (Inert (V xm)) (fun body_v ->
    reify body_v (fun body_t ->
    k @@ NLam (xm, body_t)))
  | Inert i -> render_inert i k
  | Cache (c, v') -> cached_reify c v' k

(** [cached_reify d v k] passes the term stored in [d] to [k] if there is
    one. Otherwise it reifies [v], stores the result in [d], then passes it
    to [k]. *)
and cached_reify (d : n_term cache) (v : value) (k : n_term -> n_term)
    : n_term =
  match !d with
  | Some y -> k y
  | None ->
    reify v (fun y ->
      d := Some y;
      k y)

(** [apply_value v v2 k] applies [v] to the argument [v2] and passes the
    result to [k]. A cached inert head becomes an [ICache] so that its
    rendering is shared. *)
and apply_value (v : value) (v2 : value) (k : value -> 'a) : 'a =
  match v with
  | Abs (x, t', e)     -> apply_abs x t' e v2 k
  | Inert i            -> k @@ apply_inert i v2
  | Cache (c, Inert i) -> k @@ apply_inert (ICache (c, i)) v2
  | Cache (_, v')      -> apply_value v' v2 k

(** [apply_inert i v'] builds the stuck application of [i] to [v']. *)
and apply_inert (i : inert) (v' : value) : value =
  Inert (IApp (i, v'))

(** [render_inert i k] reads back the inert term [i] and passes it to [k].
    In an application, the argument is read back before the head. *)
and render_inert (i : inert) (k : n_term -> n_term) : n_term =
  match i with
  | V x -> k @@ NVar x
  | IApp (head, v) ->
    reify v (fun n2 ->
    render_inert head (fun n1 ->
    k @@ NApp (n1, n2)))
  | ICache (c, head) -> cached_reify c (Inert head) k

(** [normal_form t] evaluates [t] in the empty environment and reads the
    resulting value back as a term. The read-back's final continuation is
    the identity. *)
let normal_form (t : n_term) : n_term =
  let finish (n : n_term) : n_term = n in
  eval t Dict.empty (fun v -> reify v finish)

(* Run the shared test suite against this normalizer. The result is
   discarded with [let _] because the return type of [cbv_tests] is not
   visible here. NOTE(review): if it is [unit], prefer [let () = ...]. *)
let _ = normal_form |> cbv_tests

