# Writing an equation parser in Scala

I recently started diving into Scala and as a first learning project, I decided to build a functional mathematical expression tokenizer, parser and evaluator. In this post I will walk through the implementation and some of the core algorithms. Feel free to reach out with any questions!

You can find the final project here.

## Overview

The goal of the project is to be able to pass a mathematical expression (e.g. `1 + 2 * 3 - 4`) to the program and get the value of the expression back (i.e. `3`).

Internally, this involves three steps. First, the raw string expression needs to be tokenized into a well-defined format as preparation for further processing. Second, a parser reads the tokens and parses them into a mathematical expression following the usual precedence rules. Lastly, the expression is evaluated by combining its components.

## Allowed operations

Let's quickly define what the final program should be able to do.

Numbers should be specified as integers (e.g. `42`) or decimals (e.g. `1.23`) and should be evaluated as doubles.

Negative numbers should be evaluated correctly, but should always be wrapped in parentheses (i.e. write `(-42)+2` instead of `-42+2`).

The following operations should be allowed:

``````
+    Addition
-    Subtraction
*    Multiplication
/    Division
^    Power
``````

Left `(` and right `)` parentheses should also be allowed and the resulting precedence should be taken into account.

## Defining Expressions

We start by defining the expressions. In total, we will need the expressions `Number`, `Sum`, `Difference`, `Product`, `Division`, and `Power`.

In fact, each component of an expression is itself an expression. For example, `1 + 2` consists of two expressions of type `Number`, namely `Number(1)` and `Number(2)`, as well as a `Sum` operator.

Moreover, every expression is either of type `Number` (which wraps a double) or `Operator` (which has a left and a right expression). Continuing with the example from above, the `Sum` operator has the left expression `Number(1)` and the right expression `Number(2)`.

Lastly, we differentiate between regular operators and `Commutative` operators. A commutative operator does not differentiate between left and right, as the order doesn't matter. For example, `x+y` always equals `y+x` (commutative) but `x-y` does not necessarily equal `y-x`.

With that in mind, we can define the `Expression` trait and the `Operator` trait, which extends `Expression`.

``````scala
/** Represents an arithmetic expression. */
trait Expression

/** Arithmetic operator with a left and a right operand. */
trait Operator extends Expression {
  val left: Expression
  val right: Expression
}
``````

As laid out above, a `Commutative` operator is simply an `Operator` that doesn't differentiate between left and right. We can achieve this behavior by extending `Operator` and overriding `equals` and `hashCode` accordingly.

``````scala
/**
 * Represents operators whose left and right side are commutative,
 * i.e. the order of the LHS and RHS expression does not matter.
 */
trait Commutative extends Operator {
  // Only operators of the same concrete class are comparable,
  // so e.g. Sum(a, b) never equals Product(a, b).
  def canEqual(a: Any): Boolean = a != null && a.getClass == this.getClass

  override def equals(that: Any): Boolean =
    that match {
      case that: Commutative =>
        that.canEqual(this) &&
          ((this.left == that.left && this.right == that.right) ||
            (this.left == that.right && this.right == that.left))
      case _ => false
    }

  // Symmetric in left and right, so swapped operands hash identically.
  override def hashCode: Int = {
    val prime = 31
    prime + left.hashCode * right.hashCode
  }
}
``````

The actual expression types can now be implemented as case classes that extend any of the `Expression` traits.

``````scala
case class Number(n: Double) extends Expression {
  def value = n
}

case class Sum(left: Expression, right: Expression) extends Commutative
case class Difference(left: Expression, right: Expression) extends Operator
case class Product(left: Expression, right: Expression) extends Commutative
case class Division(left: Expression, right: Expression) extends Operator
case class Power(left: Expression, right: Expression) extends Operator
``````
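Here is a minimal, self-contained sketch that shows what the overridden `equals` and `hashCode` do. It bundles the traits into a hypothetical `CommutativeDemo` object. It also assumes a class-based `canEqual`, so a `Sum` and a `Product` with identical operands are not considered equal.

```scala
object CommutativeDemo {
  trait Expression

  trait Operator extends Expression {
    val left: Expression
    val right: Expression
  }

  // Same idea as the Commutative trait above. This sketch assumes canEqual
  // compares concrete classes, so Sum and Product never compare equal.
  trait Commutative extends Operator {
    def canEqual(a: Any): Boolean = a != null && a.getClass == getClass

    override def equals(that: Any): Boolean = that match {
      case that: Commutative =>
        that.canEqual(this) &&
          ((left == that.left && right == that.right) ||
            (left == that.right && right == that.left))
      case _ => false
    }

    // Symmetric in left/right, so swapped operands produce the same hash.
    override def hashCode: Int = 31 + left.hashCode * right.hashCode
  }

  case class Number(n: Double) extends Expression
  case class Sum(left: Expression, right: Expression) extends Commutative
  case class Product(left: Expression, right: Expression) extends Commutative
  // Not commutative: keeps the field-by-field equals generated for case classes.
  case class Difference(left: Expression, right: Expression) extends Operator
}
```

With these definitions, `Sum(Number(1), Number(2)) == Sum(Number(2), Number(1))` is `true`. The same comparison for `Difference` is `false`, because `Difference` keeps the generated case class `equals`.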

Evaluating these expressions is now almost trivial. Any operator can be evaluated by simply applying the underlying mathematical operation to the left and right sides (which in turn need to be evaluated). For `Number` expressions, we simply return the underlying value. Scala's pattern matching comes in handy here.

``````scala
/**
 * Evaluate an expression.
 * @param expression: expression to evaluate
 * @return value of the expression
 */
def evaluate(expression: Expression): Double = expression match
  case Number(n)               => n
  case Sum(left, right)        => evaluate(left) + evaluate(right)
  case Difference(left, right) => evaluate(left) - evaluate(right)
  case Product(left, right)    => evaluate(left) * evaluate(right)
  case Division(left, right)   => evaluate(left) / evaluate(right)
  case Power(left, right)      => scala.math.pow(evaluate(left), evaluate(right))
``````

That's the expression implementation done. If we wanted to represent the example `1 + 2 * 3 - 4` from earlier using our implementation, it would give us

`Difference(Sum(Number(1), Product(Number(2), Number(3))), Number(4))`

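Every expression can also be viewed as a syntax tree, with operators as inner nodes and numbers as leaves. For example, parsing the equation `3+4*2/(1-5)^2^3` yields a tree with a `Sum` at its root. The `Sum` adds `3` to a `Division`, whose divisor is `(1-5)^(2^3)` because `^` associates right. The final version of `Expression.scala` can be found here.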
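Both example expressions can be checked by building them by hand. This sketch uses stripped-down copies of the case classes inside a hypothetical `ExpressionDemo` object. It leaves out the `Commutative` equality, which `evaluate` does not need. `tree` encodes `3+4*2/(1-5)^2^3` with `^` grouped to the right.

```scala
object ExpressionDemo {
  sealed trait Expression
  case class Number(n: Double) extends Expression
  case class Sum(left: Expression, right: Expression) extends Expression
  case class Difference(left: Expression, right: Expression) extends Expression
  case class Product(left: Expression, right: Expression) extends Expression
  case class Division(left: Expression, right: Expression) extends Expression
  case class Power(left: Expression, right: Expression) extends Expression

  // Same recursive evaluation as above.
  def evaluate(expression: Expression): Double = expression match {
    case Number(n)               => n
    case Sum(left, right)        => evaluate(left) + evaluate(right)
    case Difference(left, right) => evaluate(left) - evaluate(right)
    case Product(left, right)    => evaluate(left) * evaluate(right)
    case Division(left, right)   => evaluate(left) / evaluate(right)
    case Power(left, right)      => math.pow(evaluate(left), evaluate(right))
  }

  // 1 + 2 * 3 - 4
  val simple: Expression =
    Difference(Sum(Number(1), Product(Number(2), Number(3))), Number(4))

  // 3+4*2/(1-5)^2^3, where ^ groups to the right: (1-5)^(2^3)
  val tree: Expression =
    Sum(
      Number(3),
      Division(
        Product(Number(4), Number(2)),
        Power(Difference(Number(1), Number(5)), Power(Number(2), Number(3)))
      )
    )
}
```

Here `evaluate(simple)` returns `3.0`, and `evaluate(tree)` returns approximately `3.000122`, which is `3 + 8 / 65536`.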

## Tokenizer

Next, we will take care of tokenizing the raw string expression.

The tokenizer splits the raw expression into its components and encodes each component as a token. The tokens closely follow the `Expression` classes defined earlier; however, we also have to handle parentheses here.

We start by defining a simple abstract class `Token` and an abstract subclass `OperatorToken`, which represents operators. Each operator has a `precedence`: operators with a higher precedence value bind more tightly and are evaluated first (e.g. `*` before `+`).

``````scala
/** Token of an expression */
abstract class Token()

/** Token representing an operator */
abstract class OperatorToken() extends Token {
  /** Precedence value of the operator */
  def precedence: Int
}
``````

Next, we have to handle how the different operators associate. A left-associative operator groups from left to right (`8-2-1` means `(8-2)-1`), while a right-associative operator groups from right to left (`2^3^2` means `2^(3^2)`).

This is implemented using the following three traits.

``````scala
trait Associates

/** Associates left */
trait Left extends Associates

/** Associates right */
trait Right extends Associates
``````

This gives us all the components to define our actual tokens. `NumberToken`, `LeftParensToken`, and `RightParensToken` extend `Token` directly. All others are `OperatorToken` subclasses and also extend an `Associates` trait. Only the power token (`^`) associates right; all others associate left.

``````scala
/** Token representing sum */
case class SumToken() extends OperatorToken, Left {
  def precedence = 2
}

/** Token representing difference */
case class DifferenceToken() extends OperatorToken, Left {
  def precedence = 2
}

/** Token representing product */
case class ProductToken() extends OperatorToken, Left {
  def precedence = 3
}

/** Token representing division */
case class DivisionToken() extends OperatorToken, Left {
  def precedence = 3
}

/** Token representing power */
case class PowerToken() extends OperatorToken, Right {
  def precedence = 4
}

/** Token representing number */
case class NumberToken(n: Double) extends Token

/** Token representing left parenthesis */
case class LeftParensToken() extends Token

/** Token representing right parenthesis */
case class RightParensToken() extends Token
``````
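Associativity only matters in chains of operators with equal precedence. The `Left` and `Right` traits decide how such a chain is grouped. You can see the difference by folding the same chain from either side with the standard `reduceLeft` and `reduceRight` collection methods. The `AssociativityDemo` object is purely illustrative.

```scala
object AssociativityDemo {
  // Group a chain a op b op c from the left: (a op b) op c
  def foldLeftAssoc(operands: List[Double], op: (Double, Double) => Double): Double =
    operands.reduceLeft(op)

  // Group the same chain from the right: a op (b op c)
  def foldRightAssoc(operands: List[Double], op: (Double, Double) => Double): Double =
    operands.reduceRight(op)
}
```

Subtraction must be grouped from the left, so `8-2-1` is `5`, not `7`. Power is conventionally grouped from the right, so `2^3^2` is `512`, not `64`. This is why only `PowerToken` extends `Right`.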

What remains is the function to actually convert a raw string expression into a list of tokens.

The function will have the following signature (we will fill in the body step by step).

``````scala
/**
 * Tokenize string expression.
 *
 * Tokenize a string representation of an arithmetic expression.
 *
 * @param rawExpression: String representation of expression.
 * @return List of tokens of the expression.
 */
def tokenize(rawExpression: String): List[Token] = ???
``````

In the method body, we will have to add the following.

### Splitting the raw string expression

``````scala
val splitted = rawExpression
  .filterNot(_.isWhitespace)
  .split("(?=[)(+/*-])|(?<=[)(+/*-])|(?=[\\^])|(?<=[\\^])")
  .map(_.trim)
``````

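This removes all whitespace from the raw string and then splits it into its parts using a regular expression. The lookahead `(?=...)` and lookbehind `(?<=...)` alternatives split right before and right after every operator and parenthesis, so multi-character numbers such as `42` or `1.23` stay intact.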
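You can run the splitting step on its own to check what the regular expression produces. The `SplitDemo` object below is a hypothetical wrapper around the same pattern.

```scala
object SplitDemo {
  // Split points: right before and right after every operator or parenthesis.
  private val splitRegex = "(?=[)(+/*-])|(?<=[)(+/*-])|(?=[\\^])|(?<=[\\^])"

  def split(rawExpression: String): List[String] =
    rawExpression
      .filterNot(_.isWhitespace) // drop all whitespace first
      .split(splitRegex)         // java.lang.String.split with a regex
      .toList
}
```

For `(-42)^3+(-42)` this yields `(`, `-`, `42`, `)`, `^`, `3`, `+`, `(`, `-`, `42`, `)`. On Java 8 and later, `String.split` does not emit an empty leading element for a zero-width match at the start of the input. As a result, expressions that begin with `(` split cleanly.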

### Tokenizing string components

The val `splitted` now contains an array of strings that can be converted into tokens. For that, we define the `tokenizeOne` function and map it over each element of `splitted`.

``````scala
// Regex representing a double.
val numPattern = "(\\-?\\d*\\.?\\d+)".r

/**
 * Tokenize a single string.
 *
 * @param x: String to tokenize
 * @return Corresponding token
 */
def tokenizeOne(x: String): Token = x match {
  case "+" => SumToken()
  case "-" => DifferenceToken()
  case "*" => ProductToken()
  case "/" => DivisionToken()
  case "^" => PowerToken()
  case "(" => LeftParensToken()
  case ")" => RightParensToken()
  case numPattern(c) => NumberToken(c.toDouble)
  case _ => throw new RuntimeException(s""""$x" is not legal""")
}

// tokenize each element
val tokenized = splitted.map(tokenizeOne).toList
``````
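The number pattern accepts an optional minus sign, an optional integer part, an optional decimal point, and at least one trailing digit. A small sketch in a hypothetical `NumberPatternDemo` object makes the accepted forms explicit. Note that `Regex.matches` requires Scala 2.13 or later.

```scala
object NumberPatternDemo {
  // Same pattern as numPattern above.
  val numPattern = "(\\-?\\d*\\.?\\d+)".r

  // Regex.matches succeeds only if the whole string matches.
  def isNumber(s: String): Boolean = numPattern.matches(s)

  // Mirrors the extractor style used in tokenizeOne (also a full match).
  def toDouble(s: String): Option[Double] = s match {
    case numPattern(c) => Some(c.toDouble)
    case _             => None
  }
}
```

The pattern rejects `1.` because at least one digit must follow the decimal point. In practice a leading minus never reaches it, because the split step always separates `-` into its own element.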

### Handling negative numbers

At this stage, the tokenizer can handle all (valid) input and convert it into a list of tokens. However, we would run into problems later when parsing negative numbers (e.g. `1+(-1)`): the `-` is converted into a `DifferenceToken`, yet there is no `NumberToken` to its left.

To handle this, we will walk through the list of tokens, and if we find a `(-` sequence in the list, we insert a `NumberToken(0)` between the `(` and the `-`, producing a valid negative number.

The algorithm could be implemented more concisely with a loop; however, I wanted to find a functional implementation. Thus, I came up with the following:

``````scala
/**
 * Handle negative numbers.
 *
 * Negative numbers are prefixed with a zero
 * e.g. (-1) -> (0-1)
 * to maintain both a left and right expression of the Difference operator.
 *
 * @param tokens: list of tokens to handle
 * @return tokens with inserted zeros
 */
def handleNegative(tokens: List[Token]): List[Token] =

  /**
   * Insert prefix zero.
   *
   * If the token list starts with "(, -", return "(" followed by a zero;
   * otherwise return only the head of the list.
   *
   * @param tokens: list of tokens to check
   */
  def insert(tokens: List[Token]): List[Token] = tokens match {
    case (l: LeftParensToken) :: (_: DifferenceToken) :: _ => l :: NumberToken(0) :: Nil
    case a :: _ => a :: Nil
    case Nil => Nil
  }

  /** Recursively insert zeros where necessary */
  def recur(tokens: List[Token]): List[Token] =
    if tokens.isEmpty then Nil
    else if tokens.tail.isEmpty then tokens
    else insert(tokens) ++ recur(tokens.tail)

  recur(tokens)
``````

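The algorithm recursively walks through the list. If the list starts with the pattern `(-`, `insert` returns the `(` followed by a `NumberToken(0)`; otherwise, it returns just the head of the list. In both cases, `recur` then continues with the tail of the list, so the `-` itself is emitted in the next step.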

In the body of `tokenize` we now simply call

``````scala
// handle negative numbers
handleNegative(tokenized)
``````

as the last line.
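The same transformation can also be written as a single tail-recursive pass with an accumulator. This is an alternative sketch, not the implementation above. The token names (`LeftParens`, `Minus`, `Num`, and so on) are hypothetical stand-ins for the real token classes.

```scala
import scala.annotation.tailrec

object NegativeDemo {
  // Hypothetical, trimmed-down token types for this sketch only.
  sealed trait Token
  case object LeftParens extends Token
  case object RightParens extends Token
  case object Minus extends Token
  case object Plus extends Token
  case class Num(n: Double) extends Token

  /** Insert a zero between every "(" that is directly followed by "-". */
  def handleNegative(tokens: List[Token]): List[Token] = {
    @tailrec
    def loop(rest: List[Token], acc: List[Token]): List[Token] = rest match {
      // Emit "(", "0", "-" (prepended in reverse order) and continue after the "-".
      case LeftParens :: Minus :: tail => loop(tail, Minus :: Num(0) :: LeftParens :: acc)
      case token :: tail               => loop(tail, token :: acc)
      case Nil                         => acc.reverse
    }
    loop(tokens, Nil)
  }
}
```

The `recur` function above is not tail recursive. `loop`, by contrast, runs in constant stack space, at the cost of one `reverse` at the end.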

The final `Tokenizer.scala` file can be found here.

## Parser

In this last step, we can finally write the actual parser, which converts the token list into an expression that can be evaluated.

The parser will work in two steps.

1. Convert the token list from infix to postfix notation
2. Parse the postfix token list into an expression

### Converting infix to postfix

When we input the expression as a raw string, we use what is called infix notation, where an operator sits between its two operands. While this makes sense for human readability, the parsing algorithm used here requires postfix notation, where an operator follows its operands.

To give an example, the infix expression `(5-6) * 7` becomes `5 6 - 7 *` in postfix notation.

The conversion is done with the shunting yard algorithm. My implementation is based on the pseudocode provided on Wikipedia, but implemented recursively. I won't go into much detail on how the algorithm works, as the Wikipedia article does a pretty good job of explaining it.
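A compact sketch may help you see the whole algorithm at once. It follows the Wikipedia pseudocode but works on plain string tokens. It keeps the operator stack in an immutable `List` whose head is the top of the stack. The `ShuntingYardSketch` object, the string tokens, and the precedence table are assumptions of this sketch; the precedence values mirror the token classes above.

```scala
import scala.annotation.tailrec

object ShuntingYardSketch {
  // Precedence values mirror the token classes: higher binds tighter.
  private val precedence = Map("+" -> 2, "-" -> 2, "*" -> 3, "/" -> 3, "^" -> 4)
  // Only the power operator associates right.
  private val rightAssoc = Set("^")

  private def isOperator(t: String): Boolean = precedence.contains(t)

  /** Converts infix string tokens to postfix (RPN). */
  def toPostfix(tokens: List[String]): List[String] = {
    // Must the operator `top` on the stack be output before `op` is pushed?
    def shouldPop(op: String, top: String): Boolean =
      isOperator(top) && (
        precedence(top) > precedence(op) ||
          (precedence(top) == precedence(op) && !rightAssoc(op)))

    // `stack`: head is the top of the operator stack.
    // `out`: the output collected in reverse order.
    @tailrec
    def loop(rest: List[String], stack: List[String], out: List[String]): List[String] =
      rest match {
        case Nil =>
          if (stack.contains("(")) throw new IllegalArgumentException("Mismatched parentheses")
          out.reverse ++ stack
        case "(" :: tail => loop(tail, "(" :: stack, out)
        case ")" :: tail =>
          // Output operators up to the matching "(", then drop the "(".
          val (ops, remaining) = stack.span(_ != "(")
          if (remaining.isEmpty) throw new IllegalArgumentException("Mismatched parentheses")
          loop(tail, remaining.tail, ops.reverse ::: out)
        case op :: tail if isOperator(op) =>
          val (popped, remaining) = stack.span(top => shouldPop(op, top))
          loop(tail, op :: remaining, popped.reverse ::: out)
        case number :: tail => loop(tail, stack, number :: out)
      }

    loop(tokens, Nil, Nil)
  }
}
```

For the infix tokens of `(5-6)*7`, `toPostfix` returns `5 6 - 7 *`. Using `span` pops every operator that has to be output in one step. This removes the need for a separate helper that threads a mutable `Stack` through the recursion.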

Let's start by defining the core recursive procedure of the function.

``````scala
import scala.annotation.tailrec
import scala.collection.mutable.Stack

/**
 * Shunting Yard algorithm.
 *
 * Converts a list of tokens from infix to postfix notation.
 * https://en.wikipedia.org/wiki/Shunting_yard_algorithm
 *
 * @param tokens: List of tokens in infix notation
 * @return list of tokens in postfix notation
 */
def shuntingYard(tokens: List[Token]): List[Token] =

  // todo: we will fill in the helper functions later

  /**
   * Recursive method of shunting yard.
   *
   * @param stack: Stack of tokens left to place
   * @param postfix: Tokens converted to postfix notation
   * @param tokens: Tokens in infix notation
   */
  @tailrec
  def recur(stack: Stack[Token], postfix: List[Token], tokens: List[Token]): List[Token] =
    tokens match {
      // No input left: append the remaining operators, top of the stack first.
      case Nil => postfix ++ stack
      case t :: rest => {
        t match {
          case n: NumberToken => recur(stack, postfix :+ n, rest)
          case o: OperatorToken => {
            if stack.isEmpty then recur(stack.push(o), postfix, rest)
            else {
              val (newStack, newPostfix) = operatorUpdate(postfix, stack, o)
              recur(newStack, newPostfix, rest)
            }
          }
          case l: LeftParensToken => recur(stack.push(l), postfix, rest)
          case r: RightParensToken => {
            val (newStack, newPostfix) = rightParensUpdate(postfix, stack)
            recur(newStack, newPostfix, rest)
          }
        }
      }
    }

  recur(new Stack[Token], List(), tokens)
``````

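Essentially, this procedure walks recursively over the token list. Number tokens are appended to the postfix output right away, and left parentheses are pushed onto the operator stack. An operator first moves every stacked operator that must be applied before it to the output and is then pushed itself. A right parenthesis moves operators to the output until the matching left parenthesis is found. Once all tokens are consumed, the remaining operators on the stack are appended to the output.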

Notice that we're still missing some helper functions that were created to make the core procedure more readable. Insert the following helper functions into `shuntingYard` before the `recur` function.

``````scala
/** Helper method to determine if token is left associative. */
def isLeftAssoc(t: Token): Boolean = t match {
  case _: Left => true
  case _ => false
}

/** Helper method to determine if token is right associative. */
def isRightAssoc(t: Token): Boolean = t match {
  case _: Right => true
  case _ => false
}

/** Helper method to determine if token is left parenthesis. */
def isLeftParens(t: Token): Boolean = t match {
  case _: LeftParensToken => true
  case _ => false
}

/** Helper method to update postfix and stack during operator parsing. */
@tailrec
def operatorUpdate(postfix: List[Token], stack: Stack[Token], o: OperatorToken): (Stack[Token], List[Token]) =
  // Must the operator o2 on top of the stack be output before o is pushed?
  def matchCond(o: OperatorToken, stack: Stack[Token]): Boolean = stack.head match {
    case o2: OperatorToken =>
      (isLeftAssoc(o) && (o.precedence <= o2.precedence)) ||
        (isRightAssoc(o) && (o.precedence < o2.precedence))
    case _ => false
  }
  if stack.isEmpty || !matchCond(o, stack) then (stack.push(o), postfix)
  else operatorUpdate(postfix :+ stack.pop(), stack, o)

/** Helper method to update postfix and stack during right parens parsing. */
@tailrec
def rightParensUpdate(postfix: List[Token], stack: Stack[Token]): (Stack[Token], List[Token]) =
  if stack.isEmpty then throw new RuntimeException("Mismatched parentheses")
  else if isLeftParens(stack.head) then
    stack.pop() // discard the matching left parenthesis
    (stack, postfix)
  else rightParensUpdate(postfix :+ stack.pop(), stack)
``````

The `shuntingYard` function returns a new token list in postfix notation.

### Parsing postfix token list into an expression

Given the postfix token list, the second step of the parser converts it into a full expression.

The algorithm is quite simple. It recursively walks through the token list. Whenever a number token is encountered, it is wrapped in a `Number` and pushed onto a stack. Whenever an operator is encountered, the top two expressions are popped off the stack and combined into the corresponding operator expression. The result is pushed back onto the stack. Eventually, the stack only contains one expression, which is the final one.

``````scala
/**
 * Parses RPN to expression.
 *
 * Takes a list of tokens in RPN and parses the expression representation.
 * https://en.wikipedia.org/wiki/Reverse_Polish_notation
 *
 * @param tokens: List of tokens in postfix notation
 * @return parsed expression
 */
def parsePostfix(tokens: List[Token]): Expression =

  /** Helper method to determine if token is operator. */
  def isOperator(t: Token): Boolean = t match {
    case _: OperatorToken => true
    case _ => false
  }

  /**
   * Recursive method of the algorithm.
   *
   * @param stack: Stack of expressions to parse
   * @param tokens: Tokens to parse.
   * @return parsed expression
   */
  @tailrec
  def recur(stack: Stack[Expression], tokens: List[Token]): Expression = tokens match {
    case Nil => stack.pop()
    case t :: rest => {
      if isOperator(t) then t match {
        // The first pop returns the right operand. Sum and Product therefore
        // receive swapped operands, which is harmless since both are commutative.
        case _: SumToken => stack.push(Sum(stack.pop(), stack.pop()))
        case _: DifferenceToken => val x = stack.pop(); stack.push(Difference(stack.pop(), x))
        case _: ProductToken => stack.push(Product(stack.pop(), stack.pop()))
        case _: DivisionToken => val x = stack.pop(); stack.push(Division(stack.pop(), x))
        case _: PowerToken => val x = stack.pop(); stack.push(Power(stack.pop(), x))
        case _ => throw new RuntimeException(s""""$t" is not an operator""")
      }
      else t match {
        case NumberToken(n) => stack.push(Number(n))
        case _ => throw new RuntimeException(s""""$t" is not valid here""")
      }
      recur(stack, rest)
    }
  }

  recur(new Stack[Expression], tokens)
``````
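The stack discipline may be easiest to see with an immutable `List` as the stack. There, the pattern `right :: left :: below` takes both operands off the stack in one step and keeps their order explicit. The following is a sketch over string tokens, with trimmed-down copies of the expression types in a hypothetical `PostfixDemo` object. It is not a drop-in replacement for `parsePostfix`.

```scala
import scala.annotation.tailrec

object PostfixDemo {
  // Trimmed-down expression types for this sketch (hypothetical copies).
  sealed trait Expression
  case class Number(n: Double) extends Expression
  case class Sum(left: Expression, right: Expression) extends Expression
  case class Difference(left: Expression, right: Expression) extends Expression
  case class Product(left: Expression, right: Expression) extends Expression
  case class Division(left: Expression, right: Expression) extends Expression
  case class Power(left: Expression, right: Expression) extends Expression

  private val operators = Set("+", "-", "*", "/", "^")

  // Build the operator node for a string token.
  private def combine(op: String, left: Expression, right: Expression): Expression = op match {
    case "+" => Sum(left, right)
    case "-" => Difference(left, right)
    case "*" => Product(left, right)
    case "/" => Division(left, right)
    case "^" => Power(left, right)
  }

  def parsePostfix(tokens: List[String]): Expression = {
    // The stack is an immutable List whose head is the most recent expression.
    @tailrec
    def loop(rest: List[String], stack: List[Expression]): Expression = (rest, stack) match {
      case (Nil, result :: Nil) => result
      case (Nil, _) => throw new IllegalArgumentException("Malformed postfix expression")
      // The top of the stack is the right operand, the element below it the left one.
      case (op :: tail, right :: left :: below) if operators(op) =>
        loop(tail, combine(op, left, right) :: below)
      case (op :: _, _) if operators(op) =>
        throw new IllegalArgumentException(s"Not enough operands for $op")
      case (number :: tail, _) => loop(tail, Number(number.toDouble) :: stack)
    }
    loop(tokens, Nil)
  }
}
```

Here the operands keep their original order. As a result, the postfix form `1 2 3 * + 4 -` produces exactly `Difference(Sum(Number(1), Product(Number(2), Number(3))), Number(4))`, the expression shown earlier for `1 + 2 * 3 - 4`.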

### Putting it together

With the `shuntingYard` and `parsePostfix` functions in place, we can define the `parse` function, which we can call to actually perform the parsing.

``````scala
/**
 * Run the parser.
 *
 * Converts the tokens to postfix notation (RPN) and parses the result into an expression.
 *
 * @param tokens: List of tokens in infix notation
 * @return Parsed expression
 */
def parse(tokens: List[Token]): Expression =
  val postfix = shuntingYard(tokens)
  parsePostfix(postfix)
``````

And that is the parser done. You can find the final version of `Parser.scala` here.

## Main.scala

To make the project runnable, we need a main method in `Main.scala`. It only takes a few lines:

``````scala
object Main {
  def main(args: Array[String]): Unit = args match
    case Array(x) => printResult(getResult(x))
    case Array() => throw new IllegalArgumentException("Too few arguments!")
    case _ => throw new IllegalArgumentException("Too many arguments!")

  /** Tokenize and parse expression. */
  def getResult(rawExpr: String): Expression = parse(tokenize(rawExpr))

  /** Print the result of the evaluation. */
  def printResult(expr: Expression): Unit = println(evaluate(expr))
}
``````

## Testing the parser

Let's try out the parser on some expressions. Fire up your `sbt` server and run some examples.

``````
sbt:equation-parser> run "1+2"
3.0
sbt:equation-parser> run "3 * (1 + 2) ^ 7"
6561.0
sbt:equation-parser> run "100 / 8 - (2 * 3) + 4 ^ 3"
70.5
sbt:equation-parser> run "((3 + 2) * (2 + 1)) ^ 2"
225.0
sbt:equation-parser> run "(-42)^3+(-42)"
-74130.0
``````

Seems to work!

(We should probably write some more thorough tests to make sure it actually works, which I did here.)

## Conclusion

Implementing this Tokenizer/Parser/Evaluator in a functional way in Scala proved to be a very insightful and fun learning project. Maybe it will inspire someone to build their own parser or improve upon mine.

Feel free to reach out with any comments or feedback!