#### 作业内容描述：

1. Complete the type inference (the two substitution functions)
2. Complete the implementation for let-polymorphism
3. Think about how to handle recursive functions

#### 作业实现

##### 任务1
``````rescript
module List = Belt.List

// Source expressions of the toy language.
type rec expr =
  | CstI(int)
  | CstB(bool)
  | Var(string)
  | If(expr, expr, expr)
  | Fun(string, expr)
  | App(expr, expr)
  | Add(expr, expr)

// Monotypes: base types, type variables (printed as `@name`) and functions.
type rec typ = TInt | TBool | TVar(string) | TArr(typ, typ)

let rec to_string = (t: typ) =>
  switch t {
  | TInt => "Int"
  | TBool => "Bool"
  | TVar(x) => "@" ++ x
  | TArr(x, y) => "(" ++ to_string(x) ++ "->" ++ to_string(y) ++ ")"
  }

// Typing context: variable name -> type, innermost binding first.
type context = list<(string, typ)>
// Equality constraints between types, solved later by unification.
type constraints = list<(typ, typ)>

let rec cs_to_string = (cs: constraints) =>
  switch cs {
  | list{} => ""
  | list{(a, b), ...rest} =>
    "[" ++ to_string(a) ++ " = " ++ to_string(b) ++ "], " ++ cs_to_string(rest)
  }

// Fresh type variables are named by a global counter: @1, @2, ...
let tvar_count = ref(0)
let new_tvar = (): typ => {
  tvar_count.contents = tvar_count.contents + 1
  TVar(Js.Int.toString(tvar_count.contents))
}

// Constraint generation: returns the type of `expr` under `ctx` together with
// the equality constraints that must hold for that type to be valid.
// Raises (assert) when a variable is not bound in `ctx`.
let rec check_expr = (ctx: context, expr: expr): (typ, constraints) => {
  switch expr {
  | CstI(_) => (TInt, list{})
  | CstB(_) => (TBool, list{})
  | Var(x) => (
      switch List.getAssoc(ctx, x, (a, b) => a == b) {
      | Some(xt) => xt
      | _ => assert false // the variable must be bound in the context
      },
      list{},
    )
  | If(e1, e2, e3) => {
      // The condition is Bool and both branches share one type.
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      let (t3, c3) = check_expr(ctx, e3)
      (t2, List.concatMany([c1, c2, c3, list{(t1, TBool), (t2, t3)}]))
    }
  | Fun(x, e) => {
      // The parameter gets a fresh variable, solved from its uses in the body.
      let tx = new_tvar()
      let (te, c) = check_expr(list{(x, tx), ...ctx}, e)
      (TArr(tx, te), c)
    }
  | App(e1, e2) => {
      // e1 must be a function from e2's type to a fresh result type t.
      let t = new_tvar()
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      (t, List.concatMany([c1, c2, list{(t1, TArr(t2, t))}]))
    }
  | Add(e1, e2) => {
      // Both operands are Int (t1 = t2 and t1 = Int); the sum is Int.
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      (TInt, List.concatMany([c1, c2, list{(t1, t2), (t1, TInt)}]))
    }
  }
}

// Occurs check: true when type variable `x` appears anywhere inside `t`.
// Binding `x` to such a `t` would require an infinite (recursive) type.
let rec occurs = (x: string, t: typ): bool =>
  switch t {
  | TVar(y) => y == x
  | TArr(dom, cod) => occurs(x, dom) || occurs(x, cod)
  | TInt | TBool => false
  }

// A substitution: type-variable name -> type, most recent binding first.
type subst = list<(string, typ)>

// Render a substitution as a sequence of "name= type, " entries.
let st_to_string = (st: subst) => {
  let rec loop = (acc, entries) =>
    switch entries {
    | list{} => acc
    | list{(name, ty), ...more} => loop(acc ++ name ++ "= " ++ to_string(ty) ++ ", ", more)
    }
  loop("", st)
}

// Replace every occurrence of type variable `x` inside `t` with `xt`,
// descending into arrow types.
let rec replace_tvar = (t: typ, x: string, xt: typ): typ =>
  switch t {
  | TVar(y) if y == x => xt
  | TArr(t1, t2) => TArr(replace_tvar(t1, x, xt), replace_tvar(t2, x, xt))
  | TInt | TBool | TVar(_) => t
  }

// Apply the binding x := xt to both sides of every remaining constraint.
// The replacement must reach variables nested inside arrows. Otherwise a later
// constraint such as (x -> y) = (Bool -> z) would bind an already-solved x a
// second time, and the conflicting solutions would go undetected.
let rest_subst = (cs: constraints, x: string, xt: typ): constraints =>
  cs->List.map(((a, b)) => (replace_tvar(a, x, xt), replace_tvar(b, x, xt)))

// 约束求解
// Constraint solving (unification). Constraints are processed left to right.
// The accumulated bindings are returned most recent first. A binding's
// right-hand side may mention variables bound later, so read results through
// `type_subst`. Raises (assert) on a type mismatch or a recursive type.
let solve = (cs: constraints): subst => {
  let rec go = (cs, s): subst => {
    switch cs {
    | list{} => s
    | list{c, ...rest} =>
      switch c {
      | (TInt, TInt) | (TBool, TBool) => go(rest, s)
      // a = a holds trivially. Without this case the occurs check below would
      // reject it; e.g. If(c, x, x) produces the constraint @x = @x.
      | (TVar(x), TVar(y)) if x == y => go(rest, s)
      | (TArr(t1, t2), TArr(t3, t4)) => go(list{(t1, t3), (t2, t4), ...rest}, s)
      | (TVar(x), t) | (t, TVar(x)) => {
          assert !occurs(x, t) // forbid recursive types
          go(rest_subst(rest, x, t), list{(x, t), ...s})
        }
      | _ => assert false // mismatch such as Int = Bool or Int = a -> b
      }
    }
  }
  go(cs, list{})
}

// Look up the binding for type variable `x` in `s` (first match wins).
// An unbound variable resolves to itself.
let rec subst_resolve = (x: string, s: subst): typ =>
  switch s {
  | list{} => TVar(x)
  | list{(y, yt), ..._} if y == x => yt
  | list{_, ...rest} => subst_resolve(x, rest)
  }

// Apply substitution `s` to `t`. Variable bindings are followed repeatedly
// until the result contains no variable that `s` still binds.
let rec type_subst = (t: typ, s: subst): typ =>
  switch t {
  | TInt | TBool => t
  | TArr(dom, cod) => TArr(type_subst(dom, s), type_subst(cod, s))
  | TVar(x) =>
    switch subst_resolve(x, s) {
    | TVar(y) as same if y == x => same
    | resolved => type_subst(resolved, s)
    }
  }

// 类型推导
// Infer the type of a closed expression. Constraints are generated in an
// empty context and solved, and the solution is applied to the expression's type.
let infer = (expr: expr): typ => {
  let (ty, cs) = check_expr(list{}, expr)
  type_subst(ty, solve(cs))
}

// `a` is an If condition (Bool). `f` is applied to both `b` and `a`, and its
// result is added to 1. The expected type is therefore
// (Bool -> Int) -> Bool -> Bool -> Int.
let test = Fun(
  "f",
  Fun("a", Fun("b", If(Var("a"), Add(App(Var("f"), Var("b")), CstI(1)), App(Var("f"), Var("a"))))),
)

let inferred = test->infer
inferred->to_string->Js.log

// Self-application is not typeable. Solving @x = (@x -> @r) fails the occurs
// check, so `infer(omega)` raises and the log below is never reached.
let omega = Fun("x", App(Var("x"), Var("x")))
let omega_inferred = omega->infer
omega_inferred->to_string->Js.log
``````
##### 任务2和3
``````rescript
module List = Belt.List

// Source expressions. `Let` is recursive: the bound name is in scope inside
// its own definition, which is how recursive functions are written.
type rec expr =
  | CstI(int)
  | CstB(bool)
  | Var(string)
  | If(expr, expr, expr)
  | Fun(string, expr)
  | App(expr, expr)
  | Add(expr, expr)
  | Let(string, expr, expr)

// Types and type schemes share one representation:
//   TVar(x) - unification variable created during constraint generation
//   IVar(x) - unification variable created by instantiating a scheme
//   QVar(x) - quantified variable; only appears in context entries (schemes)
// TVar and IVar are solved the same way and differ only in how they print.
type rec typ = TInt | TBool | TArr(typ, typ) | TVar(string) | QVar(string) | IVar(string)

let rec to_string = (t: typ) =>
  switch t {
  | TInt => "Int"
  | TBool => "Bool"
  | TArr(x, y) => "(" ++ to_string(x) ++ "->" ++ to_string(y) ++ ")"
  | TVar(x) => "@" ++ x
  | IVar(x) => "#" ++ x
  | QVar(x) => "\$" ++ x
  }

// Typing context: variable name -> type or scheme, innermost binding first.
type context = list<(string, typ)>
// Equality constraints between types, solved by unification.
type constraints = list<(typ, typ)>

let rec cs_to_string = (cs: constraints) =>
  switch cs {
  | list{} => ""
  | list{(a, b), ...rest} => to_string(a) ++ " = " ++ to_string(b) ++ ", " ++ cs_to_string(rest)
  }

// One counter names both kinds of variables, so every TVar/IVar name is unique.
let var_count = ref(0)
let new_tvar = (): typ => {
  var_count.contents = var_count.contents + 1
  TVar(Js.Int.toString(var_count.contents))
}
let new_ivar = (): typ => {
  var_count.contents = var_count.contents + 1
  IVar(Js.Int.toString(var_count.contents))
}

// ---------------------------------------------------------------------------
// Unification. This section comes before `check_expr` because `Let` must solve
// the constraints of the bound expression before it can generalize its type.

// Occurs check: does unification variable `x` appear inside `t`?
// QVars never reach constraints (they are replaced by `inst`), hence the assert.
let rec occurs = (x: string, t: typ): bool => {
  switch t {
  | TInt | TBool => false
  | TVar(y) | IVar(y) => x == y
  | TArr(t1, t2) => occurs(x, t1) || occurs(x, t2)
  | QVar(_) => assert false
  }
}

// A substitution: variable name -> type, most recent binding first.
type subst = list<(string, typ)>

let rec st_to_string = (st: subst) =>
  switch st {
  | list{} => ""
  | list{(a, b), ...rest} => a ++ "= " ++ to_string(b) ++ ", " ++ st_to_string(rest)
  }

// Replace unification variable `x` by `xt` everywhere inside `t`, including
// inside arrow types.
let rec replace_var = (t: typ, x: string, xt: typ): typ =>
  switch t {
  | TVar(y) | IVar(y) if y == x => xt
  | TArr(t1, t2) => TArr(replace_var(t1, x, xt), replace_var(t2, x, xt))
  | _ => t
  }

// Apply x := xt to both sides of every remaining constraint. The replacement
// is deep, so a solved variable cannot be silently re-bound later.
let rest_subst = (cs: constraints, x: string, xt: typ): constraints =>
  cs->List.map(((a, b)) => (replace_var(a, x, xt), replace_var(b, x, xt)))

// Solve constraints left to right. Bindings are returned most recent first and
// should be read through `type_subst`. Raises (assert) on a mismatch or a
// recursive type.
let solve = (cs: constraints): subst => {
  let rec go = (cs: constraints, s: subst): subst =>
    switch cs {
    | list{} => s
    | list{c, ...rest} =>
      switch c {
      | (TInt, TInt) | (TBool, TBool) => go(rest, s)
      // a = a holds trivially; the occurs check below would wrongly reject it.
      | (TVar(x) | IVar(x), TVar(y) | IVar(y)) if x == y => go(rest, s)
      | (TArr(t1, t2), TArr(t3, t4)) => go(list{(t1, t3), (t2, t4), ...rest}, s)
      | (TVar(x) | IVar(x), t) | (t, TVar(x) | IVar(x)) => {
          assert !occurs(x, t) // forbid recursive types
          go(rest_subst(rest, x, t), list{(x, t), ...s})
        }
      | _ => assert false // type mismatch, e.g. Int = Bool or Int = a -> b
      }
    }
  go(cs, list{})
}

// Look up `x` in `s` (first match wins); an unbound name resolves to TVar(x).
let rec subst_resolve = (x: string, s: subst): typ => {
  switch s {
  | list{} => TVar(x)
  | list{(y, yt), ...rest} =>
    if y == x {
      yt
    } else {
      subst_resolve(x, rest)
    }
  }
}

// Apply `s` to `t`. Bindings of both TVars and IVars are followed until no
// bound variable remains. Unbound variables keep their own constructor.
let rec type_subst = (t: typ, s: subst): typ =>
  switch t {
  | TVar(x) | IVar(x) =>
    switch List.getAssoc(s, x, (a, b) => a == b) {
    | Some(xt) => type_subst(xt, s)
    | None => t
    }
  | TArr(t1, t2) => TArr(type_subst(t1, s), type_subst(t2, s))
  | _ => t
  }

// Names of the unification variables in `t`. QVars are bound, so they are skipped.
let rec free_vars = (t: typ): list<string> =>
  switch t {
  | TVar(x) | IVar(x) => list{x}
  | TArr(t1, t2) => List.concat(free_vars(t1), free_vars(t2))
  | TInt | TBool | QVar(_) => list{}
  }

// Instantiate a scheme: each distinct QVar becomes one fresh IVar.
let inst = (ty: typ): typ => {
  let rec go = (ty: typ, mapping: context): (typ, context) => {
    switch ty {
    | TArr(t1, t2) => {
        let (t1, mapping) = go(t1, mapping)
        let (t2, mapping) = go(t2, mapping)
        (TArr(t1, t2), mapping)
      }
    | QVar(x) =>
      switch List.getAssoc(mapping, x, (a, b) => a == b) {
      | Some(y) => (y, mapping)
      | None => {
          let y = new_ivar()
          (y, list{(x, y), ...mapping})
        }
      }
    | _ => (ty, mapping)
    }
  }
  let (t, _) = go(ty, list{})
  t
}

// Generalize every unification variable of `ty` except those named in `fixed`.
let rec gen_except = (ty: typ, fixed: list<string>): typ =>
  switch ty {
  | TArr(t1, t2) => TArr(gen_except(t1, fixed), gen_except(t2, fixed))
  | TVar(x) | IVar(x) if !List.has(fixed, x, (a, b) => a == b) => QVar(x)
  | _ => ty
  }

// Generalize every unification variable of `ty`.
let gen = (ty: typ): typ => gen_except(ty, list{})

// ---------------------------------------------------------------------------
// Constraint generation. Returns the type of `expr` under `ctx` and the
// constraints it depends on. Raises (assert) on an unbound variable, and on
// ill-typed `Let` definitions, whose constraints are solved eagerly.
let rec check_expr = (ctx: context, expr: expr): (typ, constraints) => {
  switch expr {
  | CstI(_) => (TInt, list{})
  | CstB(_) => (TBool, list{})
  | Var(x) => {
      let ty = switch List.getAssoc(ctx, x, (a, b) => a == b) {
      | Some(xt) => inst(xt)
      | _ => assert false // the variable must be bound in the context
      }
      (ty, list{})
    }
  | If(e1, e2, e3) => {
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      let (t3, c3) = check_expr(ctx, e3)
      (t2, List.concatMany([c1, c2, c3, list{(t1, TBool), (t2, t3)}]))
    }
  | Fun(x, e) => {
      // Lambda-bound variables stay monomorphic.
      let tx = new_tvar()
      let (te, c) = check_expr(list{(x, tx), ...ctx}, e)
      (TArr(tx, te), c)
    }
  | App(e1, e2) => {
      let t = new_tvar()
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      (t, List.concatMany([c1, c2, list{(t1, TArr(t2, t))}]))
    }
  | Add(e1, e2) => {
      // Both operands are Int (t1 = t2 and t1 = Int); the sum is Int.
      let (t1, c1) = check_expr(ctx, e1)
      let (t2, c2) = check_expr(ctx, e2)
      (TInt, List.concatMany([c1, c2, list{(t1, t2), (t1, TInt)}]))
    }
  | Let(x, e1, e2) => {
      // Recursion: type e1 with x bound to a fresh monomorphic variable, then
      // require that variable to equal e1's own type.
      let xt = new_tvar()
      let (t1, c1) = check_expr(list{(x, xt), ...ctx}, e1)
      let c1 = list{(xt, t1), ...c1}
      // Let-polymorphism: solve e1's constraints before generalizing.
      // Generalizing the unsolved t1 would discard the constraints on its
      // variables (e.g. `fun x -> if x ...` would be usable at Int).
      let s1 = solve(c1)
      let t1 = type_subst(t1, s1)
      // Variables still reachable from the enclosing context are shared with
      // outer bindings and must not be generalized.
      let fixed = ctx->List.reduce(list{}, (acc, (_, ty)) =>
        List.concat(free_vars(type_subst(ty, s1)), acc)
      )
      let (t2, c2) = check_expr(list{(x, gen_except(t1, fixed)), ...ctx}, e2)
      // c1 is still returned so outer variables it constrains get solved globally.
      (t2, List.concat(c1, c2))
    }
  }
}

// Infer the type of a closed expression.
let infer = (expr: expr): typ => {
  let (t, cs) = check_expr(list{}, expr)
  type_subst(t, solve(cs))
}

// let test = Let("a", Fun("x", Var("x")), Let("b", Var("a"), App(Var("b"), CstI(10))))

// let test = Let("id", Fun("x", Var("x")), Let("a", App(Var("id"), CstI(42)), Var("a")))

// Recursive function: `a` refers to itself inside its own definition.
let test = Let("a", Fun("x", App(Var("a"), Var("x"))), App(Var("a"), CstI(42)))

let inferred = test->infer
inferred->to_string->Js.log
``````
