type identifier = string

(* Named lambda terms: variables, applications and abstractions. *)
type n_term =
  | NVar of identifier
  | NApp of n_term * n_term
  | NLam of identifier * n_term;;

(* [n_term_fold_right var app lam t] is the structural fold over [t]:
   every [NVar x] becomes [var x], every [NApp (u, v)] becomes [app u' v'],
   and every [NLam (x, e)] becomes [lam x e'], where [u'], [v'], [e'] are the
   already-folded subterms. *)
let n_term_fold_right (var : identifier -> 'a) (app : 'a -> 'a -> 'a)
    (lam : identifier -> 'a -> 'a) (t : n_term) : 'a =
  let rec go (node : n_term) : 'a =
    match node with
    | NVar x -> var x
    | NApp (u, v) -> app (go u) (go v)
    | NLam (x, e) -> lam x (go e)
  in
  go t;;

(* Identity combinator at the meta (OCaml) level. *)
let meta_id : 'a -> 'a = fun value -> value;;

(* [meta_const x] ignores its argument and always returns [x]. *)
let meta_const (x : 'a) (_ignored : 'b) : 'a = x;;

(* Pretty printer for NLambda terms. The output is fully parenthesised:
   an application prints as "(l r)" and an abstraction as "(λx.e)". *)
let n_term_repr : n_term -> string =
  n_term_fold_right
    (fun name -> name)
    (fun fn arg -> Printf.sprintf "(%s %s)" fn arg)
    (fun binder body -> Printf.sprintf "(λ%s.%s)" binder body);;

(* Iteration is used in construction of complex examples of terms.
   [iter n f x] applies [f] to [x] [n] times, i.e. f (f (... (f x))).
   The recursion is a tail call, so large [n] does not grow the stack.
   A non-positive [n] applies [f] zero times and returns [x]; testing
   [n <= 0] rather than [n = 0] keeps negative counts from looping forever. *)
let rec iter (n:int) (f:'a -> 'a) (x:'a) : 'a =
  if n <= 0 then x else iter (n - 1) f (f x);;

(* Integer ranges are used to generate term families.
   [ints b e] is the half-open range [b; b+1; ...; e-1], or [] when b >= e.
   The list is built from its last element downwards with an accumulator, so
   the helper is tail-recursive and long ranges do not overflow the stack.
   Stopping on [k = b] (instead of [k < b]) avoids wrap-around when b = min_int. *)
let ints (b:int) (e:int) : int list =
  let rec build (k:int) (acc:int list) : int list =
    let acc = k :: acc in
    if k = b then acc else build (k - 1) acc
  in
  if b >= e then [] else build (e - 1) [];;

(* Here we start building examples of NLambda terms. The simplest one
   is the identity function. *)
let id : n_term = NLam ("x", NVar "x");;

(* [appseq es] builds, from a list of terms, the left-nested sequence of
   applications of those terms: [appseq [e1; e2; e3]] is ((e1 e2) e3).
   A singleton list yields its only element; the empty list yields [id]. *)
let appseq (es:n_term list) : n_term =
  match es with
  | [] -> id
  | head :: args -> List.fold_left (fun acc arg -> NApp (acc, arg)) head args;;


(* More examples of NLambda terms *)

(* Church booleans: [yes] returns its first argument, [no] its second. *)
let yes = NLam("x", NLam("y", NVar "x"));;
let no  = NLam("x", NLam("y", NVar "y"));;

(* The K combinator is the same term as Church [yes]. *)
let const = yes;;

(* Church numeral 0 is the same term as Church [no]. *)
let zero = no;;

(* [church n] is the Church numeral λf.λx. f (f (... (f x))) with [n]
   applications of f.
   @raise Invalid_argument if [n] is negative (there is no such numeral, and
   iterating a negative number of times must not be attempted). *)
let church (n:int) : n_term =
  if n < 0 then invalid_arg "church: negative argument"
  else NLam("f", NLam("x", iter n (fun e -> NApp(NVar "f", e)) (NVar "x")));;

(* λn. n (K no) yes: [yes] for Church 0, [no] for any other numeral. *)
let is_zero =
  let n = NVar "n" in
  NLam ("n", appseq [n; NApp (const, no); yes]);;

(* Church successor λn.λf.λx. n f (f x). *)
let succ =
  let n = NVar "n" and f = NVar "f" and x = NVar "x" in
  NLam ("n", NLam ("f", NLam ("x", NApp (NApp (n, f), NApp (f, x)))));;

(* λn.λm. n succ m: applies [succ] n times to m. *)
let addition =
  let n = NVar "n" and m = NVar "m" in
  NLam ("n", NLam ("m", appseq [n; succ; m]));;

(* λn.λm. n (addition m) zero: adds m to [zero], n times. *)
let multiplication =
  let n = NVar "n" and m = NVar "m" in
  NLam ("n", NLam ("m", appseq [n; NApp (addition, m); zero]));;

(* λg.λf.λx. g (f x). On Church numerals this computes their product. *)
let composition =
  let g = NVar "g" and f = NVar "f" and x = NVar "x" in
  NLam ("g", NLam ("f", NLam ("x", NApp (g, NApp (f, x)))));;

(* yes no yes no yes: (yes no yes) reduces to no, and (no no yes) to yes. *)
let koma = appseq [ yes; no; yes; no; yes ];;

(* λx. x (id x) x, whose normal form is λx. x x x. *)
let inert1 =
  let x = NVar "x" in
  NLam ("x", appseq [x; NApp (id, x); x]);;

(* Self-application λx. x x. *)
let omega : n_term =
  let x = NVar "x" in
  NLam ("x", NApp (x, x));;

(* The diverging term (λx. x x) (λx. x x). *)
let big_omega : n_term = NApp (omega, omega);;

(* [singleton_of a] is λf. f a. *)
let singleton_of (a:n_term) : n_term =
  let f = NVar "f" in
  NLam ("f", NApp (f, a));;

(* [pair_of a b] is the Church pair λf. f a b. *)
let pair_of (a:n_term) (b:n_term) : n_term =
  let f = NVar "f" in
  NLam ("f", appseq [f; a; b]);;

(* Pair constructor λx.λy.λf. f x y: [pair_of] abstracted over both components. *)
let pair : n_term =
  NLam ("x", NLam ("y", pair_of (NVar "x") (NVar "y")));;

(* λx. λf. f x x: pairs a term with itself. *)
let dubleton : n_term =
  let x = NVar "x" in
  NLam ("x", pair_of x x);;

(* Projections λf. f yes and λf. f no. Applied to a Church pair they return its
   first and second component, respectively. These top-level names shadow
   Stdlib's [fst] and [snd] from here on. *)
let fst : n_term = singleton_of yes;;
let snd : n_term = singleton_of no;;

(* λp. (snd p, succ (snd p)): maps the pair (a, b) to (b, b+1). *)
let pred_aux : n_term =
  let second = NApp (snd, NVar "p") in
  NLam ("p", pair_of second (NApp (succ, second)));;

(* Predecessor: iterate [pred_aux] n times from (0, 0), reaching (n-1, n)
   for n >= 1 (and staying at (0, 0) for n = 0), then take the first component. *)
let pred = NLam ("n", appseq [NVar "n"; pred_aux; pair_of zero zero; yes]);;

(* Truncated subtraction λn.λm. m pred n, i.e. max(n - m, 0). *)
let monus =
  let n = NVar "n" and m = NVar "m" in
  NLam ("n", NLam ("m", appseq [m; pred; n]));;

(* omega applied to λx.λy. x y; reducing it requires renaming bound variables. *)
let alpha =
  let apply_xy = NLam ("x", NLam ("y", NApp (NVar "x", NVar "y"))) in
  NApp (omega, apply_xy);;

(* λx. (λy. y (λz. y)) x *)
let alpha2 =
  let self_const = NLam ("y", NApp (NVar "y", NLam ("z", NVar "y"))) in
  NLam ("x", NApp (self_const, NVar "x"));;

(* λx. (λy. y (λz. y)) (x x) *)
let alpha3 =
  let self_const = NLam ("y", NApp (NVar "y", NLam ("z", NVar "y"))) in
  NLam ("x", NApp (self_const, NApp (NVar "x", NVar "x")));;

(* (λx.λy.λz. x y) (λw. w w) *)
let example : n_term =
  let apply_skip_z = NLam ("x", NLam ("y", NLam ("z", NApp (NVar "x", NVar "y")))) in
  let self_app = NLam ("w", NApp (NVar "w", NVar "w")) in
  NApp (apply_skip_z, self_app);;

(* Fixed-point combinator λf. (λx. f (x x)) (λx. f (x x)). *)
let y_comb : n_term =
  let half = NLam ("x", NApp (NVar "f", NApp (NVar "x", NVar "x"))) in
  NLam ("f", NApp (half, half));;

(* Variant whose self-application is wrapped as λz. x x z, so it is not
   reduced before the function is called. *)
let y_cbv : n_term =
  let delayed = NLam ("z", appseq [NVar "x"; NVar "x"; NVar "z"]) in
  let half = NLam ("x", NApp (NVar "f", delayed)) in
  NLam ("f", NApp (half, half));;

(* λf.λn. is_zero n 1 (composition n (f (pred n))), i.e. n * f (n - 1). *)
let factorial_step : n_term =
  let n = NVar "n" in
  let recursive = NApp (NVar "f", NApp (pred, n)) in
  NLam ("f", NLam ("n",
    appseq [is_zero; n; church 1; appseq [composition; n; recursive]]));;

(* Same as [factorial_step], with the recursive call wrapped as λz. f (pred n) z. *)
let factorial_step_cbv : n_term =
  let n = NVar "n" in
  let recursive = NLam ("z", appseq [NVar "f"; NApp (pred, n); NVar "z"]) in
  NLam ("f", NLam ("n",
    appseq [is_zero; n; church 1; appseq [composition; n; recursive]]));;

(* Factorial via the fixed-point combinators. [cbn_tests] uses [factorial]
   and [cbv_tests] uses [factorial_cbv]. *)
let factorial     = NApp (y_comb, factorial_step);;
let factorial_cbv = NApp (y_cbv,  factorial_step_cbv);;

(* Test terms à la Crégut 2007 *)

(* (church n) (church 2) id *)
let test1 (n:int) : n_term =
  let two = church 2 in
  appseq [church n; two; id];;

(* pred applied to the Church numeral n *)
let test2 (n:int) : n_term =
  let numeral = church n in
  NApp (pred, numeral);;

(* λx. x (λy. x) x; already in normal form. *)
let lambda_back : n_term =
  let x = NVar "x" in
  NLam ("x", appseq [x; NLam ("y", x); x]);;

(* λx. (church n) omega x *)
let explode (n:int) : n_term =
  let x = NVar "x" in
  NLam ("x", appseq [church n; omega; x]);;

(* λx. (church n) (λz. z (λy. z)) x *)
let explode2 (n:int) : n_term =
  let z = NVar "z" in
  let step = NLam ("z", NApp (z, NLam ("y", z))) in
  NLam ("x", appseq [church n; step; NVar "x"]);;

(* λx. (church n) (λy. y (λz. (λw. y) y)) x *)
let inert_redex (n:int) : n_term =
  let y = NVar "y" in
  let step = NLam ("y", NApp (y, NLam ("z", NApp (NLam ("w", y), y)))) in
  NLam ("x", appseq [church n; step; NVar "x"]);;

(* λx. (church n) dubleton x *)
let explode3 (n:int) : n_term =
  let x = NVar "x" in
  NLam ("x", appseq [church n; dubleton; x]);;

(* (church n) dubleton id *)
let explode4 (n:int) : n_term = appseq [church n; dubleton; id];;

(* (church n) dubleton (λx. id x) *)
let implode2 (n:int) : n_term =
  let eta_id = NLam ("x", NApp (id, NVar "x")) in
  appseq [church n; dubleton; eta_id];;

(* [implode n] nests [dubleton] applications: implode 1 = dubleton id and
   implode n = dubleton (λ_. implode (n-1)) for n >= 2. implode 0 is the free
   variable y, so that instance is an open term.
   @raise Invalid_argument if [n] is negative; the recursion would otherwise
   never reach a base case and overflow the stack. *)
let rec implode (n:int) : n_term =
  if n < 0 then invalid_arg "implode: negative argument"
  else
    match n with
    | 0 -> NVar "y"
    | 1 -> NApp (dubleton, id)
    | _ -> NApp (dubleton, NLam ("_", implode (n - 1)));;

(* (church n) fst (explode4 n) *)
let contained_explosion (n:int) : n_term = appseq [church n; fst; explode4 n];;

(* Term family [quadratic_q]. NOTE(review): the name suggests it exposes
   quadratic behaviour in some evaluator; confirm intent. *)

(* [quadratic_a n] is x applied to n pairs of arguments (λy.y) and x:
   quadratic_a 0 = x and quadratic_a n = quadratic_a (n-1) (λy.y) x.
   @raise Invalid_argument if [n] is negative (otherwise the recursion never
   reaches its base case). *)
let rec quadratic_a (n : int) : n_term =
  if n < 0 then invalid_arg "quadratic_a: negative argument"
  else if n = 0 then NVar "x"
  else appseq [quadratic_a (n - 1); NLam ("y", NVar "y"); NVar "x"]

(* [quadratic_b n] is z (λw. z (λw. ... z)) with n nested abstractions.
   @raise Invalid_argument if [n] is negative. *)
let rec quadratic_b (n : int) : n_term =
  if n < 0 then invalid_arg "quadratic_b: negative argument"
  else if n = 0 then NVar "z"
  else NApp (NVar "z", NLam ("w", quadratic_b (n - 1)))

(* [quadratic_q n] is λx. (λz. quadratic_b n) (quadratic_a n).
   @raise Invalid_argument if [n] is negative. *)
let quadratic_q (n : int) : n_term =
  NLam ("x", NApp (NLam ("z", quadratic_b n), quadratic_a n))

(*let square_step (n:int) : n_term =
  let rec aux (m:int) : term =
    NLam("y", match m with
    | 0 -> NVar n
    | _ -> NApp( NVar(n-m), aux (m-1) )) in
  NLam("x", NApp(aux n, iter n (fun t -> NApp(NVar "x", t)) @@ NVar "x"));;*)

(* [list_indexof p xs] is [Some i] where [i] is the 0-based position of the
   first element of [xs] satisfying [p], or [None] if there is none. *)
let list_indexof (p : 'a -> bool) (xs : 'a list) : int option =
  let rec search (pos : int) (rest : 'a list) : int option =
    match rest with
    | [] -> None
    | y :: tail when p y -> Some pos
    | _ :: tail -> search (pos + 1) tail
  in
  search 0 xs;;

(* [alpha_equivalent t1 e1 t2 e2] decides whether [t1], under the enclosing
   binders [e1], and [t2], under the enclosing binders [e2], are equal up to
   renaming of bound variables. Binder lists are innermost-first, so a bound
   variable is identified by the position of its nearest binder. A variable
   bound in only one of the terms never matches. A variable free in both
   terms is compared by name, so the check is also sound for open terms
   (previously any two free variables compared equal). *)
let rec alpha_equivalent (t1 : n_term) (e1 : string list) (t2 : n_term) (e2 : string list) : bool =
  match t1, t2 with
  | NApp(l1, r1), NApp(l2, r2) -> alpha_equivalent l1 e1 l2 e2 && alpha_equivalent r1 e1 r2 e2
  | NLam(x1, b1), NLam(x2, b2) -> alpha_equivalent b1 (x1::e1) b2 (x2::e2)
  | NVar x1, NVar x2 ->
    (match list_indexof (String.equal x1) e1, list_indexof (String.equal x2) e2 with
     | Some i1, Some i2 -> i1 = i2
     | None, None -> String.equal x1 x2
     | Some _, None | None, Some _ -> false)
  | (NApp _ | NLam _ | NVar _), _ -> false;;

(* [t1 =@ t2] holds when [t1] and [t2] are alpha-equivalent. *)
let ( =@ ) (t1 : n_term) (t2 : n_term) : bool = alpha_equivalent t1 [] t2 [];;

(* A sequence of tests, parameterized by an evaluation method. We
   check whether a given evaluation method correctly evaluates these terms
   to their normal forms. The expected result is a list consisting only of
   the constant [true]. Each element compares, up to alpha-equivalence
   ([=@]), the result of [eval] on a term with its expected normal form.
   NOTE(review): OCaml leaves the evaluation order of list elements
   unspecified, so if [eval] has side effects (e.g. tracing) do not rely on
   the tests running top to bottom. *)
let common_tests (eval : n_term -> n_term) : bool list = [
  (* [succ] is expected to be returned unchanged. *)
  eval succ =@ succ;
  (* Reductions that require renaming bound variables. *)
  eval alpha =@ NLam("x", NLam("y", NApp(NVar "x", NVar "y")));
  eval alpha2 =@ NLam("x", NApp(NVar "x", NLam("y", NVar "x")));
  eval alpha3 =@ NLam("x", NApp(NApp(NVar "x", NVar "x"), NLam("y", NApp(NVar "x", NVar "x"))));
  (* NOTE(review): disabled test. As written it applies [=@] to two lists,
     which does not type-check ([=@] compares two [n_term]s). *)
  (* List.map (fun n -> eval @@ NApp(is_zero, church n)) @@ ints 0 5 =@ [yes;no;no;no;no]; *)
  (* Church arithmetic: 5 + 8, 3 * 4, composition of 3 and 4, and 4 applied to 4 (4^4 = 256). *)
  eval (appseq[addition;church 5;church 8]) =@ church 13;
  eval (appseq[multiplication;church 3;church 4]) =@ church 12;
  eval (appseq[composition;church 3;church 4]) =@ church 12;
  eval (appseq[church 4;church 4]) =@ church 256;
  (* One step of the predecessor helper: (5, 7) becomes (7, 8). *)
  eval (NApp(pred_aux, pair_of (church 5) (church 7))) =@ pair_of (church 7) (church 8);
  (* The redex (id x) sits under a binder and must be reduced there. *)
  eval inert1 =@ NLam("x", appseq[NVar "x"; NVar "x"; NVar "x"]);
  eval koma =@ yes;
  eval example =@ NLam("x", NLam("y", NApp(NVar "x", NVar "x")));
  (* [lambda_back] is expected to be returned unchanged. *)
  eval lambda_back =@ lambda_back;
  (* Both sides are evaluated; their normal forms must agree. *)
  eval (explode2 2) =@ eval (inert_redex 2);
];;

(* Tests for call-by-name evaluators, followed by [common_tests eval].
   The first two pass the diverging [big_omega] as an argument that ends up
   being discarded, so they terminate only if arguments are not evaluated
   before substitution. The third uses [factorial], built from the plain
   [y_comb]. Both sides of the third test are wrapped in [NLam ("_", ...)]:
   presumably to check that evaluation proceeds under a binder (TODO confirm). *)
let cbn_tests (eval : n_term -> n_term) : bool list = [
  eval (NApp(no,big_omega)) =@ id;
  eval (NApp(singleton_of big_omega, no)) =@ id;
  eval (NLam("_", NApp(factorial, church 5))) =@ NLam("_", church 120)
] @ common_tests eval;;

(* Tests for call-by-value evaluators, followed by [common_tests eval].
   Here [big_omega] is only ever passed guarded under a λ, and factorial uses
   [factorial_cbv] (built from [y_cbv], whose self-application is wrapped
   under λz). In the last test both pair components sit under the pair's λf
   binder, so they must be normalised there to match the expected pair. *)
let cbv_tests (eval : n_term -> n_term) : bool list = [
  eval (NApp(no,NLam("_", big_omega))) =@ id;
  eval (NLam("_", NApp(factorial_cbv, church 5))) =@ NLam("_", church 120);
  eval (pair_of (NApp(factorial_cbv, church 4)) (appseq [addition; church 2; church 2])) =@ pair_of (church 24) (church 4);
] @ common_tests eval;;


(* term measures *)

(* [var_count t]: number of variable occurrences in [t]. *)
let var_count : n_term -> int = n_term_fold_right
  (meta_const 1) (+) (fun _ -> meta_id);;

(* [app_count t]: number of application nodes in [t]. *)
let app_count : n_term -> int = n_term_fold_right
  (meta_const 0) (fun a b -> a + b + 1) (fun _ -> meta_id);;

(* [lam_count t]: number of abstraction nodes in [t].
   The increment is an explicit lambda because [succ] is shadowed by the
   Church successor term in this file. The previous [Pervasives.succ] is
   deprecated since OCaml 4.08 and no longer exists in OCaml 5. *)
let lam_count : n_term -> int = n_term_fold_right
  (meta_const 0) (+) (fun _ count -> count + 1);;

(* [lam_depth t]: maximal number of nested abstractions along any path
   from the root of [t] to a leaf. *)
let lam_depth : n_term -> int = n_term_fold_right
  (meta_const 0) max (fun _ depth -> depth + 1);;

(* capture-agnostic substitution (variable capture is not avoided) *)

(* [n_subst x t s] replaces the free occurrences of [x] in [s] by [t].
   A binder for [x] stops the substitution in its body. Binders are never
   renamed, so free variables of [t] can be captured by binders of [s]. *)
let rec n_subst (x: identifier) (t: n_term) (s: n_term) : n_term =
  let go = n_subst x t in
  match s with
  | NVar y -> if String.equal x y then t else s
  | NLam (y, body) -> if String.equal x y then s else NLam (y, go body)
  | NApp (fn, arg) -> NApp (go fn, go arg)

(* explicit normal forms *)

(* Normal forms with their shape made explicit in the type. A normal term is
   either an abstraction over a normal term or a neutral term. A neutral term
   is a variable applied to a sequence of normal arguments. *)
type normal_n_term =
  | ENLam of identifier * normal_n_term
  | ENNeut of neutral_n_term
and neutral_n_term =
  | ENVar of identifier
  | ENApp of neutral_n_term * normal_n_term;;

(* Forget the normal/neutral distinction and embed back into [n_term]. *)
let rec n_term_of_normal (nf : normal_n_term) : n_term =
  match nf with
  | ENNeut ne -> n_term_of_neutral ne
  | ENLam (binder, body) -> NLam (binder, n_term_of_normal body)
and n_term_of_neutral (ne : neutral_n_term) : n_term =
  match ne with
  | ENApp (head, arg) -> NApp (n_term_of_neutral head, n_term_of_normal arg)
  | ENVar name -> NVar name;;

(* substitutable terms *)

(* Lambda terms extended with an extra leaf [NSubs] carrying a payload of
   type ['a]. [exec_subs] replaces such leaves by terms. [subst] leaves them
   untouched, and [from_substitutable] rejects them with [assert false]. *)
type 'a sn_term = SNVar of identifier   (* variable *)
  | SNApp of 'a sn_term * 'a sn_term    (* application *)
  | SNLam of identifier * 'a sn_term    (* abstraction *)
  | NSubs of 'a                         (* payload leaf *)

(* Embed a plain term into ['a sn_term]. The result contains no [NSubs] leaf,
   so it can be used at any payload type. *)
let rec to_substitutable (t : n_term) : 'a sn_term =
  match t with
  | NLam (x, body) -> SNLam (x, to_substitutable body)
  | NApp (fn, arg) -> SNApp (to_substitutable fn, to_substitutable arg)
  | NVar x -> SNVar x

(* Convert back to a plain term. Only defined on terms without [NSubs]
   leaves; encountering one fails with [assert false]. *)
let rec from_substitutable (t : 'a sn_term) : n_term =
  match t with
  | SNVar x -> NVar x
  | SNApp (fn, arg) -> NApp (from_substitutable fn, from_substitutable arg)
  | SNLam (x, body) -> NLam (x, from_substitutable body)
  | NSubs _ -> assert false

(* [exec_subs i t] replaces every [NSubs a] leaf of [t] by [i a]. Every other
   node is rebuilt, not reused, because the payload type may change from
   ['a] to ['b]. *)
let rec exec_subs (i : 'a -> 'b sn_term) (t : 'a sn_term) : 'b sn_term =
  let recur = exec_subs i in
  match t with
  | SNLam (x, body) -> SNLam (x, recur body)
  | SNApp (fn, arg) -> SNApp (recur fn, recur arg)
  | SNVar x -> SNVar x
  | NSubs payload -> i payload

(* [subst x v t] replaces the free occurrences of [x] in [t] by [v].
   A binder for [x] stops the substitution, and [NSubs] leaves are left
   untouched. Binders are never renamed, so this is not capture-avoiding. *)
let rec subst (x : identifier) (v : 'a sn_term) (t : 'a sn_term) : 'a sn_term =
  let go = subst x v in
  match t with
  | SNVar y -> if String.equal x y then v else t
  | SNLam (y, body) -> if String.equal x y then t else SNLam (y, go body)
  | SNApp (fn, arg) -> SNApp (go fn, go arg)
  | NSubs _ -> t

