DEV Community

Adi Salimgereyev


Implementing Type Inference in Rust. Part #1: Unification

If you're writing your own programming language, you've probably heard of type inference. In this series of articles we will walk through how it works, without too much theory, and implement it ourselves in Rust.

What is type inference?

In a compiler, type inference is the stage of compilation in which the compiler deduces the types of expressions that the programmer did not annotate. It is most closely associated with functional languages, but languages from other families implement it too.

Algorithm

Disclaimer: this article keeps theory to a minimum, so some standard terms and definitions are left out.

Before it can analyze types, the compiler has to parse the source code. In this article, however, we keep the concepts of AST and HIR to a minimum and explain type inference through code examples rather than syntax trees.

Let’s imagine that we have code in a hypothetical programming language for which we want to implement type inference:

func inc(a) {
  return a + 1
}

Since the function signature specifies no types, we have to infer them ourselves. The first step is to create a type variable for every unknown type:

func inc(a: ?1) -> ?2 {
  return a + 1;
}

A type variable (written ?1, ?2, and so on) stands for a type that we have not inferred yet.
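In an implementation, type variables are usually just numbers handed out by a counter, so that every unknown type gets its own variable. Here is a minimal sketch; TypeVariableGenerator and fresh are hypothetical names used only for illustration.

```rust
// A type variable is identified by its number: ?1 is TypeVariable(1).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct TypeVariable(usize);

// Hypothetical helper that hands out fresh, never-reused type variables.
#[derive(Default)]
struct TypeVariableGenerator {
    next: usize,
}

impl TypeVariableGenerator {
    // Returns ?1, ?2, ?3, ... in order.
    fn fresh(&mut self) -> TypeVariable {
        self.next += 1;
        TypeVariable(self.next)
    }
}

fn main() {
    let mut vars = TypeVariableGenerator::default();
    let param_a = vars.fresh(); // type of the parameter `a`
    let ret = vars.fresh(); // return type of `inc`
    println!("a: ?{}, return: ?{}", param_a.0, ret.0);
}
```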

For simplicity, let's assume that the + operator has the signature (int, int) -> int. Then a + 1 tells us that the parameter a must be an int and that the literal 1 must be an int as well. And since a + 1 has type int and is the return value of inc, the function's return type is int too. This gives us information about the relationships between types, called type constraints:

?1 = int
int = int
?2 = int
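The reasoning above can be sketched as a tiny constraint generator for binary operators: each operand type must equal the matching parameter type of the operator's signature, and the type of the whole expression must equal its result type. This sketch uses a simplified Type enum (plain Vec, no Arc), and Constraint and constrain_binary_op are hypothetical names.

```rust
// Simplified type representation: a named constructor or a type variable.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(String, Vec<Type>),
    Variable(usize),
}

// An equality constraint `left = right`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Constraint {
    left: Type,
    right: Type,
}

fn int() -> Type {
    Type::Constructor("int".to_string(), vec![])
}

// For an operator with signature (param_l, param_r) -> ret, record
// lhs = param_l, rhs = param_r and result = ret.
fn constrain_binary_op(
    (param_l, param_r, ret): (Type, Type, Type),
    lhs: Type,
    rhs: Type,
    result: Type,
    out: &mut Vec<Constraint>,
) {
    out.push(Constraint { left: lhs, right: param_l });
    out.push(Constraint { left: rhs, right: param_r });
    out.push(Constraint { left: result, right: ret });
}

fn main() {
    let plus = (int(), int(), int()); // (int, int) -> int
    let mut constraints = Vec::new();
    // `return a + 1`: `a` has type ?1, the literal 1 has type int,
    // and `a + 1` is returned, so its type is the return type ?2.
    constrain_binary_op(plus, Type::Variable(1), int(), Type::Variable(2), &mut constraints);
    for c in &constraints {
        println!("{:?} = {:?}", c.left, c.right);
    }
}
```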

A type constraint, simply put, is a way of recording a relationship between types.

Here we only use equality constraints. Other kinds of constraints exist as well, for example subtyping constraints between numeric types, but this article does not cover them.

Now that we have this list of relationships, we can work out the value of each type variable. This is called unification: we treat the list of type constraints as a system of equations and solve it for the type variables:

?1 = int
?2 = int

You might wonder what happens if the system of equations has no solution. The answer is simple: the compiler reports a mismatched types error. For example:

func inc(a: ?1) -> ?2 {
  return a + "hello";
}

?1 = int
String = int
?2 = int

Since String can never be equal to int, this constraint cannot be satisfied and the system has no solution. The compiler therefore reports a mismatched types error:

[Image: screenshot of the resulting mismatched types error]
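To make "no solution" concrete, here is a small sketch of a solver that returns an error instead of a solution when two different constructors must be equal. It is a simplified, Result-based unifier written only for this example (no occurs check, plain Vec instead of Arc); resolve, unify, solve and TypeError are hypothetical names, not the implementation developed later in this article.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(String, Vec<Type>),
    Variable(usize),
}

#[derive(Debug, PartialEq, Eq)]
enum TypeError {
    Mismatch(Type, Type),
}

fn con(name: &str) -> Type {
    Type::Constructor(name.to_string(), vec![])
}

// Follows substitutions until reaching a constructor or an unsolved variable.
fn resolve(ty: &Type, subst: &HashMap<usize, Type>) -> Type {
    match ty {
        Type::Variable(v) => match subst.get(v) {
            Some(t) => resolve(t, subst),
            None => ty.clone(),
        },
        _ => ty.clone(),
    }
}

// Simplified unifier: binds unsolved variables, compares constructors.
fn unify(a: &Type, b: &Type, subst: &mut HashMap<usize, Type>) -> Result<(), TypeError> {
    let (a, b) = (resolve(a, subst), resolve(b, subst));
    match (&a, &b) {
        (Type::Variable(x), Type::Variable(y)) if x == y => Ok(()),
        (Type::Variable(x), _) => {
            subst.insert(*x, b.clone());
            Ok(())
        }
        (_, Type::Variable(y)) => {
            subst.insert(*y, a.clone());
            Ok(())
        }
        (Type::Constructor(n1, g1), Type::Constructor(n2, g2)) => {
            // Different names or arities: the constraint cannot hold.
            if n1 != n2 || g1.len() != g2.len() {
                return Err(TypeError::Mismatch(a.clone(), b.clone()));
            }
            for (x, y) in g1.iter().zip(g2) {
                unify(x, y, subst)?;
            }
            Ok(())
        }
    }
}

// Solves the constraints in order, stopping at the first mismatch.
fn solve(constraints: &[(Type, Type)]) -> Result<HashMap<usize, Type>, TypeError> {
    let mut subst = HashMap::new();
    for (l, r) in constraints {
        unify(l, r, &mut subst)?;
    }
    Ok(subst)
}

fn main() {
    // Constraints for `return a + "hello"`: ?1 = int, String = int, ?2 = int
    let constraints = vec![
        (Type::Variable(1), con("int")),
        (con("String"), con("int")),
        (Type::Variable(2), con("int")),
    ];
    match solve(&constraints) {
        Ok(_) => println!("solved"),
        Err(TypeError::Mismatch(a, b)) => println!("mismatched types: {:?} vs {:?}", a, b),
    }
}
```

Returning a Result rather than panicking lets a compiler collect the error and report it with source locations.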

Let’s take a more interesting example:

func generate_nums(count) {
  var a = [];
  for (var i = 0; i < count; i++) {
    a.insert(i);
  }
  return a;
}

First, as in the previous example, we annotate the variables and the function signature with type variables:

func generate_nums(count: ?1) -> ?2 {
  var a: ?3 = [];
  for (var i: ?4 = 0; i < count; i++) {
    a.insert(i);
  }
  return a;
}

Now look at the declaration of a: it is initialized with an empty array. We don't know the element type yet, so we use another type variable for it:

?3 = Array<?5>

Moving on to the next statement:

for (var i: ?4 = 0; i < count; i++) {

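Suppose that the ++ operator has the signature (int) -> int and the < operator has the signature (int, int) -> bool. This gives us three new constraints: the initializer 0 gives ?4 = int, comparing i with count gives ?4 = ?1, and i++ gives ?4 = int again: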

?4 = int
?4 = ?1
?4 = int

Let’s continue!

a.insert(i);

Let's suppose the insert method of an array is declared roughly like this: Array<T>.insert(element: T). Then we need one more type variable for the generic parameter T:

?3 = Array<?6> // a
?6 = ?4 // i
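Creating a new type variable for T amounts to instantiating the generic signature: every occurrence of the generic parameter is replaced by the same fresh type variable. A sketch, assuming a hypothetical Type::Generic variant for named parameters such as T (the Type enum in this article has no such variant) and a hypothetical instantiate helper:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(String, Vec<Type>),
    Variable(usize),
    Generic(String), // hypothetical: a generic parameter such as `T`
}

// Replaces each generic parameter with a type variable. A fresh variable is
// created the first time a name is seen, so every `T` maps to the same one.
fn instantiate(ty: &Type, mapping: &mut HashMap<String, usize>, next_var: &mut usize) -> Type {
    match ty {
        Type::Generic(name) => {
            let var = *mapping.entry(name.clone()).or_insert_with(|| {
                *next_var += 1;
                *next_var
            });
            Type::Variable(var)
        }
        Type::Constructor(name, generics) => Type::Constructor(
            name.clone(),
            generics.iter().map(|g| instantiate(g, mapping, next_var)).collect(),
        ),
        Type::Variable(_) => ty.clone(),
    }
}

fn main() {
    // Array<T>.insert(element: T): the receiver type and the parameter type.
    let receiver = Type::Constructor("Array".to_string(), vec![Type::Generic("T".to_string())]);
    let param = Type::Generic("T".to_string());

    let mut mapping = HashMap::new();
    let mut next_var = 5; // ?1..?5 are already in use in generate_nums
    let receiver = instantiate(&receiver, &mut mapping, &mut next_var);
    let param = instantiate(&param, &mut mapping, &mut next_var);
    // Both occurrences of T became ?6, which yields ?3 = Array<?6> (for `a`)
    // and ?6 = ?4 (for `i`).
    println!("{:?} / {:?}", receiver, param);
}
```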

Finally, let's analyze the return statement:

func generate_nums(count: ?1) -> ?2 {
  var a: ?3 = [];
  ...
  return a;
}

The function returns a, which means that:

?3 = ?2

Putting it all together, we get the following system of equations:

?3 = Array<?5>
?4 = int
?4 = ?1
?4 = int
?3 = Array<?6>
?6 = ?4
?3 = ?2

Again, the unification algorithm solves this system of equations, giving:

?1 = int
?2 = Array<int>
?3 = Array<int>
?4 = int
?5 = int
?6 = int
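One way to double-check such a result is to substitute the solution back into every constraint: for a valid solution, both sides of each equation become identical. A sketch with a simplified Type enum (no Arc); apply, satisfies and generate_nums_system are hypothetical helper names, not part of the article's code.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(String, Vec<Type>),
    Variable(usize),
}

fn con(name: &str, generics: Vec<Type>) -> Type {
    Type::Constructor(name.to_string(), generics)
}

// Replaces every solved variable in `ty`. Assumes the values in `solution`
// are fully solved (contain no variables), so one pass is enough.
fn apply(ty: &Type, solution: &HashMap<usize, Type>) -> Type {
    match ty {
        Type::Variable(v) => solution.get(v).cloned().unwrap_or_else(|| ty.clone()),
        Type::Constructor(name, gs) => {
            Type::Constructor(name.clone(), gs.iter().map(|g| apply(g, solution)).collect())
        }
    }
}

// A solution is valid if both sides of every constraint become identical.
fn satisfies(constraints: &[(Type, Type)], solution: &HashMap<usize, Type>) -> bool {
    constraints.iter().all(|(l, r)| apply(l, solution) == apply(r, solution))
}

// The system of equations for generate_nums.
fn generate_nums_system() -> Vec<(Type, Type)> {
    let v = Type::Variable;
    vec![
        (v(3), con("Array", vec![v(5)])),
        (v(4), con("int", vec![])),
        (v(4), v(1)),
        (v(4), con("int", vec![])),
        (v(3), con("Array", vec![v(6)])),
        (v(6), v(4)),
        (v(3), v(2)),
    ]
}

fn main() {
    let int = con("int", vec![]);
    let array_int = con("Array", vec![int.clone()]);
    let solution: HashMap<usize, Type> = HashMap::from([
        (1, int.clone()),
        (2, array_int.clone()),
        (3, array_int),
        (4, int.clone()),
        (5, int.clone()),
        (6, int),
    ]);
    println!("solution valid: {}", satisfies(&generate_nums_system(), &solution));
}
```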

Substituting these types back into the code gives us the result of type inference:

func generate_nums(count: int) -> Array<int> {
  var a: Array<int> = [];
  for (var i: int = 0; i < count; i++) {
    a.insert(i);
  }
  return a;
}

By now you may be wondering how this mysterious unification algorithm actually solves a system of type constraints. Let's find out!

Implementation

First, let’s write a representation of types in our language:

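use std::sync::Arc;

// A type is either a constructor (e.g. int or Array<T>) or a type variable (?1, ?2, ...).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Type {
    Constructor(TypeConstructor),
    Variable(TypeVariable),
}

// A named type applied to zero or more generic arguments, e.g. Array<int>.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeConstructor {
    name: String,
    generics: Vec<Arc<Type>>,
}

// A type we have not inferred yet, identified by its number: ?1 is TypeVariable(1).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeVariable(usize);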

With this representation:

  • int is TypeConstructor { name: "int".to_string(), generics: vec![] }.
  • List<int> is TypeConstructor { name: "List".to_string(), generics: vec![Arc::new(Type::Constructor(TypeConstructor { name: "int".to_string(), generics: vec![] }))] }.
  • ?1 is TypeVariable(1).
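The Debug output of this representation gets verbose quickly. As an optional addition that is not part of the original code, a Display implementation can print types in the ?1 / List<int> notation used throughout this article, which helps when debugging the inference:

```rust
use std::fmt;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Type {
    Constructor(TypeConstructor),
    Variable(TypeVariable),
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeConstructor {
    name: String,
    generics: Vec<Arc<Type>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct TypeVariable(usize);

// Prints `int`, `List<int>`, `Pair<int, ?2>` or `?1`.
impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Variable(TypeVariable(i)) => write!(f, "?{}", i),
            Type::Constructor(TypeConstructor { name, generics }) => {
                write!(f, "{}", name)?;
                if !generics.is_empty() {
                    write!(f, "<")?;
                    for (i, g) in generics.iter().enumerate() {
                        if i > 0 {
                            write!(f, ", ")?;
                        }
                        write!(f, "{}", g)?;
                    }
                    write!(f, ">")?;
                }
                Ok(())
            }
        }
    }
}

fn main() {
    let int = Arc::new(Type::Constructor(TypeConstructor { name: "int".to_string(), generics: vec![] }));
    let list_int = Type::Constructor(TypeConstructor { name: "List".to_string(), generics: vec![int] });
    println!("{}", list_int); // prints "List<int>"
    println!("{}", Type::Variable(TypeVariable(1))); // prints "?1"
}
```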

Now we can start implementing the unification algorithm:

use std::{collections::HashMap, iter::zip};

// Unifies `left` with `right`, recording solved type variables in `substitutions`.
fn unify(left: Arc<Type>, right: Arc<Type>,
         substitutions: &mut HashMap<TypeVariable, Arc<Type>>) {
  match (left.as_ref(), right.as_ref()) {

If both types are type constructors, we check that they have the same name and the same number of generic arguments, and then unify those arguments pairwise:

      (
          Type::Constructor(TypeConstructor {
              name: name1,
              generics: generics1,
          }),
          Type::Constructor(TypeConstructor {
              name: name2,
              generics: generics2,
          }),
      ) => {
          // Different constructors (e.g. Array vs Option) are a type mismatch.
          assert_eq!(name1, name2, "mismatched types");
          assert_eq!(generics1.len(), generics2.len(), "mismatched types");

          // Same constructor: unify the generic arguments pairwise.
          for (left, right) in zip(generics1, generics2) {
              unify(left.clone(), right.clone(), substitutions);
          }
      }

For example, from the constraint:

Array<int> = Array<?1>

it follows that:

int = ?1

If the two type constructors differ, we have a type mismatch (in this simple implementation, assert_eq! panics):

Array<...> != Option<...>

If both sides are the same type variable, everything is fine and there is nothing to do:

  (Type::Variable(TypeVariable(i)), 
   Type::Variable(TypeVariable(j))) if i == j => {}

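Otherwise, when one side is a type variable, we look it up in the substitution map (our storage of solved type variables). If it already has a value, we unify that value with the other side; if not, we record the other side as its value. Importantly, before recording it we check that we are not creating an infinite type: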

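  (_, Type::Variable(v)) => {
      // `v` already has a value: unify that value with `left` instead.
      if let Some(substitution) = substitutions.get(v) {
          unify(left, substitution.clone(), substitutions);
          return;
      }
      // `left` may already resolve to `v` (e.g. ?2 = ?1, then ?1 = ?2);
      // that is not an infinite type, so there is nothing to record.
      if *left.substitute(substitutions) == Type::Variable(*v) {
          return;
      }
      // Occurs check: refuse to create an infinite type.
      assert!(!v.occurs_in(left.clone(), substitutions), "infinite type");
      substitutions.insert(*v, left);
  }
  (Type::Variable(v), _) => {
      if let Some(substitution) = substitutions.get(v) {
          unify(right, substitution.clone(), substitutions);
          return;
      }
      if *right.substitute(substitutions) == Type::Variable(*v) {
          return;
      }
      assert!(!v.occurs_in(right.clone(), substitutions), "infinite type");
      substitutions.insert(*v, right);
  }
  } // end of match
}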

Here is what happens when we try to create an infinite type in Rust:

[Image: screenshot of the Rust compiler rejecting code that would require an infinite type]

In this example, the generic parameter T of push must match the type of the argument a.to_vec(), which is Vec<T>. This gives the constraint T = Vec<T>, whose only solution is the infinite type Vec<Vec<Vec<Vec<...>>>>. Some languages do allow such types, but to keep things simple and avoid problems, we will not accept them.
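Without a substitution map in the picture, the occurs check is just a recursive search for the variable inside the type. Here is a sketch on a simplified Type enum (no Arc, no substitutions) that rejects the constraint T = Vec<T> from the example above, writing T as ?1; occurs and check_binding are hypothetical names.

```rust
// Simplified type representation: a named constructor or a type variable.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Constructor(String, Vec<Type>),
    Variable(usize),
}

// Returns true if the variable `var` appears anywhere inside `ty`.
fn occurs(var: usize, ty: &Type) -> bool {
    match ty {
        Type::Variable(v) => *v == var,
        Type::Constructor(_, generics) => generics.iter().any(|g| occurs(var, g)),
    }
}

// Binding ?var := ty is rejected if it would create an infinite type.
// The trivial ?var = ?var is allowed, since it records nothing new.
fn check_binding(var: usize, ty: &Type) -> Result<(), String> {
    if occurs(var, ty) && *ty != Type::Variable(var) {
        Err(format!("infinite type: ?{} = {:?}", var, ty))
    } else {
        Ok(())
    }
}

fn main() {
    // `a.push(a.to_vec())` with `a: Vec<?1>` gives the constraint ?1 = Vec<?1>.
    let vec_of_t = Type::Constructor("Vec".to_string(), vec![Type::Variable(1)]);
    println!("{:?}", check_binding(1, &vec_of_t));
}
```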

Now let's implement occurs_in, which checks whether a type variable occurs in a type: if the type is a variable, we compare the two (following any existing substitution); if it is a constructor, we search its generic arguments recursively:

impl TypeVariable {
    // Returns true if `self` appears in `ty`, following existing substitutions.
    fn occurs_in(&self, ty: Arc<Type>, 
                 substitutions: &HashMap<TypeVariable, Arc<Type>>) -> bool {
        match ty.as_ref() {
            Type::Variable(v @ TypeVariable(i)) => {
                // If `v` is already solved, look inside its value instead
                // (skipping a trivial `?v = ?v` entry to avoid looping forever).
                if let Some(substitution) = substitutions.get(v) {
                    if substitution.as_ref() != &Type::Variable(*v) {
                        return self.occurs_in(substitution.clone(), substitutions);
                    }
                }

                self.0 == *i
            }
            Type::Constructor(TypeConstructor { generics, .. }) => {
                // Search the generic arguments recursively.
                generics
                    .iter()
                    .any(|generic| self.occurs_in(generic.clone(), substitutions))
            }
        }
    }
}

We also need a function that recursively walks through the stored values of type variables until no solved variables remain. For example:

?1 = ?2
?2 = ?3
?3 = int

substitute(?1) = substitute(?2) = substitute(?3) = int

impl Type {
    // Replaces every solved type variable in `self` with its value,
    // following chains like ?1 -> ?2 -> ?3 -> int until none remain.
    fn substitute(&self, substitutions: &HashMap<TypeVariable, Arc<Type>>) -> Arc<Type> {
        match self {
            // Constructors are rebuilt with substituted generic arguments.
            Type::Constructor(TypeConstructor { name, generics }) => {
                Arc::new(Type::Constructor(TypeConstructor {
                    name: name.clone(),
                    generics: generics
                        .iter()
                        .map(|t| t.substitute(substitutions))
                        .collect(),
                }))
            }
            // A solved variable is replaced recursively; an unsolved one is kept.
            Type::Variable(TypeVariable(i)) => {
                if let Some(t) = substitutions.get(&TypeVariable(*i)) {
                    t.substitute(substitutions)
                } else {
                    Arc::new(self.clone())
                }
            }
        }
    }
}

That's it, we've written a unification algorithm! Let's try it out in practice on the system of equations from the previous example:

?3 = Array<?5>
?4 = int
?4 = ?1
?4 = int
?3 = Array<?6>
?6 = ?4
?3 = ?2

Before running it, let's write two helper macros.

The first macro is shorthand for creating a type variable:

macro_rules! tvar {
    ($i:expr) => {
        Arc::new(Type::Variable(TypeVariable($i)))
    };
}

The second one is shorthand for creating a type constructor:

macro_rules! tconst {
    // tconst!("int") or tconst!("Array", tvar!(5)): a constructor name
    // followed by zero or more comma-separated generic arguments.
    ($name:expr $(, $generic:expr)* $(,)?) => {
        Arc::new(Type::Constructor(TypeConstructor {
            name: $name.to_string(),
            generics: vec![$($generic),*],
        }))
    };
}

Now let's run the previous example through our implementation of the unification algorithm:

fn main() {
    let mut substitutions = HashMap::new();

    unify(tvar!(3), tconst!("Array", tvar!(5)), &mut substitutions);
    unify(tvar!(4), tconst!("int"), &mut substitutions);
    unify(tvar!(4), tvar!(1), &mut substitutions);
    unify(tvar!(4), tconst!("int"), &mut substitutions);
    unify(tvar!(3), tconst!("Array", tvar!(6)), &mut substitutions);
    unify(tvar!(6), tvar!(4), &mut substitutions);
    unify(tvar!(3), tvar!(2), &mut substitutions);

    for i in 1..=6 {
        println!(
            "{}: {:?}",
            i,
            Type::Variable(TypeVariable(i)).substitute(&substitutions)
        );
    }
}

We get:

1: Constructor(TypeConstructor { name: "int", generics: [] })
2: Constructor(TypeConstructor { name: "Array", generics: [Constructor(TypeConstructor { name: "int", generics: [] })] })
3: Constructor(TypeConstructor { name: "Array", generics: [Constructor(TypeConstructor { name: "int", generics: [] })] })
4: Constructor(TypeConstructor { name: "int", generics: [] })
5: Constructor(TypeConstructor { name: "int", generics: [] })
6: Constructor(TypeConstructor { name: "int", generics: [] })

In the next article, we'll start writing our little programming language and begin inferring the types of simple expressions such as literals and arrays!
