DEV Community

Hercules Lemke Merscher
Hercules Lemke Merscher

Posted on • Originally published at dev.to

Tsonnet #38 - Call me maybe, but make it typed, part 3

Welcome to the Tsonnet series!

If you're not following along, check out how it all started in the first post of the series.

In the previous post, we implemented closures — both bound to locals and immediately invoked.

I was annoyed by the fact that the function type variants were difficult to read, so I went on an errand to make them simpler. We can do better!

Let's see how it turned out.

Telephone

A makeover for the AST

diff --git a/lib/ast.ml b/lib/ast.ml
index e317b6e..4b1102e 100644
--- a/lib/ast.ml
+++ b/lib/ast.ml
@@ -90,10 +90,9 @@ type expr =
   | Local of position * (string * expr) list
   | Seq of expr list
   | IndexedExpr of position * string * expr
-  | FunctionDef of position * (string * (string * expr option) list * expr)
-  | FunctionCall of position * string * expr list
-  | Closure of position * ((string * expr option) list * expr)
-  | ClosureCall of position * (string * expr option) list * expr * expr list
+  | FunctionDef of position * function_def
+  | FunctionCall of position * function_call
+  | Closure of position * closure

 and object_entry =
   | ObjectField of string * expr
Enter fullscreen mode Exit fullscreen mode

ClosureCall is gone. It's basically FunctionCall, so there's no need to have a duplicated type variant.

The payloads of these variants — function_def, function_call, and closure — are now records, which makes them much easier to reason about, since every field has a name that says what it holds:

@@ -102,6 +101,20 @@ and object_scope =
   | Self
   | TopLevel
   | ObjVarRef of string
+
+and function_def = {
+  name: string;
+  params: (string * expr option) list;
+  body: expr;
+}
+and function_call = {
+  callee: expr;
+  args: expr list;
+}
+and closure = {
+  params: (string * expr option) list;
+  body: expr;
+}
 [@@deriving show]

 let dummy_expr = Unit
Enter fullscreen mode Exit fullscreen mode

With this AST change, it became easier to simplify the parsing rules:

diff --git a/lib/parser.mly b/lib/parser.mly
index 8b3821e..e7b7771 100644
--- a/lib/parser.mly
+++ b/lib/parser.mly
@@ -22,6 +22,7 @@
 %token DOT
 %token SELF TOP_LEVEL_OBJ
 %token PLUS MINUS MULTIPLY DIVIDE MODULO
+%nonassoc FUNCTION
 %left PLUS MINUS
 %left MULTIPLY DIVIDE MODULO
 %token <string> ID
@@ -43,13 +44,11 @@
 prog:
   | e = expr; EOF { e }
   | e = expr_seq; EOF { e }
-  | e = closure; EOF { e }
   ;

 expr:
   | e = assignable_expr { e }
   | e = vars { e }
-  | e = closure_call { e }
   ;

 expr_seq:
@@ -63,6 +62,7 @@ assignable_expr:
   | e = indexed_expr { e }
   | e = obj_field_access { e }
   | e = funcall { e }
+  | e = closure { e }
   ;

 indexed_expr:
@@ -175,7 +175,6 @@ obj_field_access:

 var:
   | varname = ID; ASSIGN; e = assignable_expr { (varname, e) }
-  | varname = ID; ASSIGN; e = closure { (varname, e) }
   ;

 vars:
@@ -196,7 +195,7 @@ fundef:
   | fname = ID;
     LEFT_PAREN; params = separated_nonempty_list(COMMA, fundef_param); RIGHT_PAREN;
     ASSIGN;
-    body = fundef_body { (fname, params, body) }
+    body = fundef_body { { name = fname; params = params; body = body } }
   ;

 fundef_body:
@@ -207,20 +206,23 @@ fundef_body:
 funcall:
   | fname = ID;
     LEFT_PAREN; params = separated_nonempty_list(COMMA, assignable_expr); RIGHT_PAREN
-    { FunctionCall (with_pos $startpos $endpos, fname, params) }
+    { FunctionCall
+      (with_pos $startpos $endpos, {
+        callee = Ident (with_pos $startpos(fname) $endpos(fname), fname);
+        args = params;
+      })
+    }
+  | callee = scoped_expr;
+    LEFT_PAREN; params = separated_nonempty_list(COMMA, assignable_expr); RIGHT_PAREN
+    { FunctionCall (with_pos $startpos $endpos, { callee = callee; args = params }) }
   ;

 closure:
   | FUNCTION;
     LEFT_PAREN; params = separated_nonempty_list(COMMA, fundef_param); RIGHT_PAREN;
-    body = assignable_expr { Closure (with_pos $startpos $endpos, (params, body)) }
-  ;
-
-closure_call:
-  | LEFT_PAREN; FUNCTION;
-    LEFT_PAREN; def_params = separated_nonempty_list(COMMA, fundef_param); RIGHT_PAREN;
-    body = assignable_expr;
-    RIGHT_PAREN;
-    LEFT_PAREN; call_params = separated_nonempty_list(COMMA, assignable_expr); RIGHT_PAREN;
-    { ClosureCall (with_pos $startpos $endpos, def_params, body, call_params) }
+    body = assignable_expr {
+      Closure (with_pos $startpos $endpos, { params = params; body = body })
+    }
+    (* precedence here will transform "function(x) x * x" into "function(x) (x * x)" *)
+    %prec FUNCTION
   ;
Enter fullscreen mode Exit fullscreen mode

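I could finally move closure out of prog and expr. It was never meant to be there. Part of the problem was that FUNCTION had no declared precedence: declaring it %nonassoc and tagging the closure rule with %prec FUNCTION lets the parser read the whole expression that follows as the closure body.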

The interpreter catches up

With ClosureCall gone and FunctionCall now carrying a callee: expr, the interpreter needed a bit more surgery — resolving the callee before dispatching to apply_function:

diff --git a/lib/interpreter.ml b/lib/interpreter.ml
index 7412fb9..15f9932 100644
--- a/lib/interpreter.ml
+++ b/lib/interpreter.ml
@@ -27,9 +27,8 @@ let rec interpret env expr =
   | Seq exprs -> interpret_seq env exprs
   | IndexedExpr (pos, varname, index_expr) -> interpret_indexed_expr env (pos, varname, index_expr)
   | FunctionDef (pos, def) -> interpret_function_def env (pos, def)
-  | FunctionCall (pos, fname, params) -> interpret_function_call env (pos, fname, params)
+  | FunctionCall (pos, call) -> interpret_function_call env (pos, call)
   | Closure _ -> ok (env, expr)
-  | ClosureCall (pos, def_params, body, params) -> interpret_closure_call env (pos, def_params, body, params)

 and interpret_indexed_expr env (pos, varname, index_expr) =
   let* (env', index_expr') = interpret env index_expr in
@@ -417,8 +416,8 @@ and interpret_ident env pos varname =
     result
   end

-and interpret_function_def env (pos, (fname, params, body)) =
-  let env' = Env.add_local fname (FunctionDef (pos, (fname, params, body))) env in
+and interpret_function_def env (pos, def) =
+  let env' = Env.add_local def.name (FunctionDef (pos, def)) env in
   ok (env', Unit)

 and apply_function env pos def_params body call_params =
@@ -468,16 +467,23 @@ and apply_function env pos def_params body call_params =
     in
     ok (env, result)

-and interpret_function_call env (pos, fname, call_params) =
-  match Env.find_opt fname env with
-  | Some (Closure (pos, (def_params, body)))
-  | Some (FunctionDef (pos, (_, def_params, body))) ->
-    apply_function env pos def_params body call_params
+and interpret_function_call env (pos, call) =
+  let* (env', callee_val) =
+    match call.callee with
+    | Ident (pos, name) ->
+      (match Env.find_opt name env with
+      | Some expr -> ok (env, expr)
+      | None -> Error.error_at pos (Error.Msg.var_not_found name)
+      )
+    | _ -> interpret env call.callee
+  in
+  match callee_val with
+  | Closure (_, closure) ->
+    apply_function env' pos closure.params closure.body call.args
+  | FunctionDef (_, def) ->
+    apply_function env' pos def.params def.body call.args
   | _ ->
-    Error.error_at pos (Error.Msg.var_not_found fname)
-
-and interpret_closure_call env (pos, def_params, body, call_params) =
-  apply_function env pos def_params body call_params
+    Error.error_at pos (Error.Msg.var_not_found (string_of_type call.callee))

 let rec deep_eval expr =
   match expr with
Enter fullscreen mode Exit fullscreen mode

The type checker gets the memo

The type checker can also be improved by using the new type variants as records:

diff --git a/lib/type.ml b/lib/type.ml
index ee912b3..ad4f333 100644
--- a/lib/type.ml
+++ b/lib/type.ml
@@ -17,10 +17,10 @@ type tsonnet_type =
   | TobjectPtr of Env.env_id * t_object_scope
   | Lazy of expr
   | Tunresolved
-  | TfunctionDef of (string * tsonnet_type) list (* params: name * type *) * expr (* body *) * tsonnet_type (* return *)
-  | TfunctionCall of tsonnet_type list * tsonnet_type
-  | Tclosure of (string * tsonnet_type) list * expr
-  | TclosureCall of tsonnet_type list * tsonnet_type
+  | TfunctionDef of t_function_def
+  | TfunctionCall of t_function_call
+  | Tclosure of t_closure
+
 and t_object_entry =
   | TobjectField of string * tsonnet_type
   | TobjectExpr of tsonnet_type
@@ -28,6 +28,20 @@ and t_object_scope =
   | TobjectSelf
   | TobjectTopLevel

+and t_function_def = {
+  params: (string * tsonnet_type) list;
+  body: expr;
+  return: tsonnet_type;
+}
+and t_function_call = {
+  params: tsonnet_type list;
+  return: tsonnet_type;
+}
+and t_closure = {
+  params: (string * tsonnet_type) list;
+  body: expr;
+}
+
 let rec to_string = function
   | Tunit -> "()"
   | Tnull -> "Null"
@@ -59,25 +73,21 @@ let rec to_string = function
       | TobjectTopLevel -> "$"
     in Printf.sprintf "%s (%d)" s id
   | Lazy ty -> string_of_type ty
-  | TfunctionDef (params, _, return) ->
+  | TfunctionDef {params; return; _} ->
     Printf.sprintf "function(%s) -> %s"
       (List.map (fun (name, ty) -> name ^ ": " ^ to_string ty) params
       |> String.concat ", "
       )
       (to_string return)
-  | TfunctionCall (params_type, return) ->
+  | TfunctionCall {params=params_type; return} ->
     Printf.sprintf "function(%s) -> %s"
       (List.map to_string params_type |> String.concat ", ")
       (to_string return)
-  | Tclosure (params, _) ->
+  | Tclosure {params; _} ->
     Printf.sprintf "function(%s)"
       (List.map (fun (name, ty) -> name ^ ": " ^ to_string ty) params
       |> String.concat ", "
       )
-  | TclosureCall (params_type, return) ->
-    Printf.sprintf "function(%s) -> %s"
-      (List.map to_string params_type |> String.concat ", ")
-      (to_string return)
   | Tunresolved -> "<unresolved>"

 let rec collect_free_idents = function
@@ -95,10 +105,9 @@ let rec collect_free_idents = function
   | ObjectFieldAccess (_, _, exprs) -> List.concat_map collect_free_idents exprs
   | IndexedExpr (_, name, e) -> name :: collect_free_idents e
   | Local (_, vars) -> List.concat_map (fun (_, e) -> collect_free_idents e) vars
-  | FunctionCall (_, name, args) -> name :: List.concat_map collect_free_idents args
-  | Closure (_, (_, body)) -> collect_free_idents body
-  | ClosureCall (_, _, body, args) ->
-    collect_free_idents body @ List.concat_map collect_free_idents args
+  | FunctionCall (_, call) ->
+    collect_free_idents call.callee @ List.concat_map collect_free_idents call.args
+  | Closure (_, closure) -> collect_free_idents closure.body
   | _ -> []

 let reachable_bindings bindings initial_idents =
@@ -192,10 +201,8 @@ let rec translate venv expr =
   | UnaryOp (pos, op, expr) -> translate_unary_op venv (pos, op, expr)
   | IndexedExpr (pos, varname, index_expr) -> translate_indexed_expr venv (pos, varname, index_expr)
   | FunctionDef (pos, def) -> translate_function_def venv (pos, def)
-  | FunctionCall (pos, fname, params) -> translate_function_call venv (pos, fname, params)
-  | Closure (pos, (params, body)) -> translate_closure venv (pos, params, body)
-  | ClosureCall (pos, def_params, body, call_params) ->
-    translate_closure_call venv (pos, def_params, body, call_params)
+  | FunctionCall (pos, call) -> translate_function_call venv (pos, call)
+  | Closure (pos, closure) -> translate_closure venv (pos, closure)
   | expr' ->
     error (Error.Msg.type_invalid_expr (string_of_type expr')) 
@@ -493,7 +500,7 @@ and translate_bin_op venv pos op e1 e2 =
   | _, Tany, _ | _, _, Tany -> ok (venv'', Tany)
   | _ -> Error.error_at pos Error.Msg.invalid_binary_op

-and translate_function_def venv (pos, (fun_name, params, body)) =
+and translate_function_def venv (pos, def) =
   (* For params with defaults, we can infer the type from the default expression;
      params without defaults remain Tunresolved until the first call *)
   let* params_typed = List.fold_left
@@ -507,36 +514,57 @@ and translate_function_def venv (pos, (fun_name, params, body)) =
         ok (params' @ [(name, Tunresolved)])
     )
     (ok [])
-    params
+    def.params
+  in
+  let fun_def = TfunctionDef
+    { params = params_typed;
+      body = def.body;
+      return = Tunresolved;
+    }
   in
-  let fun_def = TfunctionDef (params_typed, body, Tunresolved) in
   (* So, function declaration will have an unresolved type definition,
      that only later it will be translated: before function call translation!
      After first function call, concrete types are set and subsequent calls will
      type check against the initial type assignment(s). *)
-  let venv' = Env.add_local fun_name fun_def venv in
+  let venv' = Env.add_local def.name fun_def venv in
   ok (venv', fun_def)

-and translate_function_call venv (pos, fname, call_params) =
-  (* 1. retrieve TfunctionDef from venv *)
-  match Env.find_opt fname venv with
-  | Some (TfunctionDef (def_params, body_expr, return_type)) ->
-    (* check arity: allow fewer args if defaults exist *)
-    let num_call = List.length call_params in
+and translate_function_call venv (pos, call) =
+  match call.callee with
+  | Ident (_, name) ->
+    translate_named_function_call venv (pos, name, call.args)
+  | Closure (_, closure) ->
+    translate_closure_call venv (pos, closure.params, closure.body, call.args)
+  | _ ->
+    let* (venv', callee_ty) = translate venv call.callee in
+    (match callee_ty with
+    | Tclosure { params = closure_params; body = body_expr } ->
+      let def_params = List.map (fun (name, _ty) -> (name, None)) closure_params in
+      translate_closure_call venv' (pos, def_params, body_expr, call.args)
+    | _ ->
+      Error.error_at pos (Error.Msg.type_invalid_expr (string_of_type call.callee))
+    )
+
+and translate_named_function_call venv (pos, name, args) =
+  match Env.find_opt name venv with
+  | Some (TfunctionDef { params = def_params;
+                         body = body_expr;
+                         return = return_type;
+                       }) ->
+    let num_call = List.length args in
     let num_def = List.length def_params in
     if num_call > num_def
     then
       Error.error_at pos
         (Error.Msg.wrong_number_of_params num_def num_call)
     else
-      (* 2. type check each positional parameter passed in the function call *)
       let* (venv', resolved_params) =
         List.fold_left
           (fun acc (index, (param_name, def_param_type)) ->
             let* (venv', params') = acc in
             if index < num_call
             then
-              let call_param = List.nth call_params index in
+              let call_param = List.nth args index in
               let* (venv'', call_param_type) = translate venv' call_param in
               match def_param_type with
               | Tunresolved ->
@@ -549,26 +577,22 @@ and translate_function_call venv (pos, fname, call_params) =
                     ~expected:(to_string expected)
                     ~got:(to_string call_param_type))
             else
-              (* default arg — resolve type from the original AST default expression *)
               ok (venv', params' @ [(param_name, def_param_type)])
           )
           (ok (venv, []))
           (List.mapi (fun i p -> (i, p)) def_params)
       in
-      (* 3. type check return *)
       let body_venv = List.fold_left
         (fun env (name, ty) -> Env.add_local name ty env)
         venv'
         resolved_params
       in
-      (* translate the body with resolved param types in scope *)
       let* (_, body_type) = translate body_venv body_expr in
-      let* resolved_return = match return_type with
+      let* resolved_return =
+        match return_type with
         | Tunresolved ->
-          (* 3a. first call: infer return type from body *)
           ok body_type
         | expected ->
-          (* 3b. subsequent calls: check body type matches *)
           if body_type = expected
           then ok expected
           else Error.error_at pos
@@ -576,23 +600,26 @@ and translate_function_call venv (pos, fname, call_params) =
               ~expected:(to_string expected)
               ~got:(to_string body_type))
       in
-      (* 4. update env with the now-resolved function type *)
-      let resolved_fun = TfunctionDef (resolved_params, body_expr, resolved_return) in
-      let venv_with_resolved_fun = Env.add_local fname resolved_fun venv' in
+      let resolved_fun =
+        TfunctionDef {
+          params = resolved_params;
+          body = body_expr;
+          return = resolved_return;
+        }
+      in
+      let venv_with_resolved_fun = Env.add_local name resolved_fun venv' in
       ok (venv_with_resolved_fun, resolved_return)
   | Some (Lazy expr) ->
     let* (venv', resolved) = translate venv expr in
-    (* Re-dispatch with the resolved type *)
-    let venv'' = Env.add_local fname resolved venv' in
-    translate_function_call venv'' (pos, fname, call_params)
-  | Some (Tclosure (def_params, body_expr)) ->
-    translate_closure_call venv (pos, 
-      List.map (fun (name, _ty) -> (name, None)) def_params,
-      body_expr, call_params)
+    let venv'' = Env.add_local name resolved venv' in
+    translate_named_function_call venv'' (pos, name, args)
+  | Some (Tclosure { params = closure_params; body = body_expr }) ->
+    let def_params = List.map (fun (name, _ty) -> (name, None)) closure_params in
+    translate_closure_call venv (pos, def_params, body_expr, args)
   | _ ->
-    Error.error_at pos (Error.Msg.var_not_found fname)
+    Error.error_at pos (Error.Msg.var_not_found name)

-and translate_closure venv (_pos, params, body) =
+and translate_closure venv (_pos, closure) =
   let* params_typed = List.fold_left
     (fun acc (name, default) ->
       let* params' = acc in
@@ -604,11 +631,11 @@ and translate_closure venv (_pos, params, body) =
         ok (params' @ [(name, Tunresolved)])
     )
     (ok [])
-    params
+    closure.params
   in
-  ok (venv, Tclosure (params_typed, body))
+  ok (venv, Tclosure { params = params_typed; body = closure.body })

-and translate_closure_call venv (pos, def_params, body_expr, call_params) =
+and translate_closure_call venv (pos, def_params, body, call_params) =
   let num_call = List.length call_params in
   let num_def = List.length def_params in
   let num_required =
@@ -641,7 +668,7 @@ and translate_closure_call venv (pos, def_params, body_expr, call_params) =
       venv'
       resolved_params
     in
-    let* (_, body_type) = translate body_venv body_expr in
+    let* (_, body_type) = translate body_venv body in
     ok (venv, body_type)

 let check (config : Config.t) expr  =
Enter fullscreen mode Exit fullscreen mode

Conclusion

Switching from tuples to named records across the AST, type checker, and interpreter was one of those refactors that feels almost mechanical — until it isn't. The real win was collapsing ClosureCall into FunctionCall and generalising the callee to an arbitrary expression. The grammar conflict that made closure a second-class citizen in the previous post? Gone. A bit of precedence annotation and a small restructuring of funcall was all it took.

Here is the entire diff.

Next up, methods!


Thanks for reading Bit Maybe Wise! Tuples are out, records are in. Your inbox deserves named fields too — subscribe and stop guessing what position two holds.

Photo by Quino Al on Unsplash

Top comments (0)