DEV Community

Sarun Tapee
Sarun Tapee

Posted on

Construction of Expression Tree

After getting the postfix form of the mathematical expression, we can construct an expression tree and evaluate the result of the expression.

infix form:
a + sqr(max(b,c)) * d

postfix form:
a b c max sqr d * + 

expression tree:

    +
   / \
  a   *
     / \
   sqr  d
   /
  max
 /  \
b    c

Enter fullscreen mode Exit fullscreen mode

postfix to expression tree algorithm

The algorithm to construct the expression tree from the postfix form is as follows.

For each element in postfix
    if operand
        push elem to stack
    if unary operator
        left_child = stack.pop
        elem->left = left_child
        push elem to stack
    if binary operator
        right_child = stack.pop
        left_child = stack.pop
        elem->left = left_child
        elem->right = right_child
        push elem to stack
root = stack.pop
Enter fullscreen mode Exit fullscreen mode

Evaluate the result of the expression tree

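To evaluate the expression tree, we perform a postorder traversal: each node's children are evaluated first, then the node's operator is applied to their results, and the value computed at the root is the result of the whole expression.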

Define tree data type

As the tree structure shows, all leaf nodes are operands and all internal nodes are operators. We will implement the tree classes as shown in the following class diagram.

MathExprTree and MathExprNode

// math_expr_tree.h
/**
 * Abstract base for every node of a math expression tree.
 *
 * Concrete subclasses are either operators (internal nodes) or typed
 * operands (leaves). The node does not free its children; its destructor
 * is empty and MathExprTree frees the whole tree.
 */
class MathExprNode {
protected:
    // Protected: the abstract base is only ever built as part of a subclass.
    MathExprNode();
public:
    // Virtual so that deleting a subclass through a MathExprNode* also runs
    // the subclass destructor.
    virtual ~MathExprNode();

    MathExprNode *parent = nullptr;  // enclosing node; null for the root
    MathExprNode *left = nullptr;    // left child (the only child of a unary operator)
    MathExprNode *right = nullptr;   // right child; null for leaves and unary operators

    // Type the subtree rooted here produces, or MATH_INVALID.
    virtual math_exp_type resultType() = 0;
    // Evaluates the subtree: a newly allocated value the caller deletes, or nullptr.
    virtual DType *eval() = 0;
};

/**
 * Owns an expression tree built from a postfix token sequence.
 *
 * The tree exclusively owns every node reachable from root and frees them
 * in its destructor. It is therefore non-copyable: a copy would share the
 * same nodes, and both destructors would delete them (double free).
 */
class MathExprTree {
private:
    // Post-order delete of the subtree rooted at node (null is a no-op).
    void clear(MathExprNode *node);
public:
    MathExprNode *root = nullptr;  // owned; may be null, which valid()/eval() check for
    MathExprTree(lex_data **postfix, int size);
    virtual ~MathExprTree();

    // Ownership of the nodes is exclusive, so copying is disabled.
    MathExprTree(const MathExprTree &) = delete;
    MathExprTree &operator=(const MathExprTree &) = delete;

    void inorderPrint(); // just to check the tree after construct
    bool valid(); // check if the data type and operator are compatible
    DType *eval(); // evaluate the result of the expression tree (caller deletes the result)

};

// math_expr_tree.cpp
#include <stack>
static void _inorderPrint(MathExprNode *node); 

// A fresh node is detached: no parent and no children.
MathExprNode::MathExprNode()
    : parent(nullptr), left(nullptr), right(nullptr) {}

// Nothing to release here; children are freed by MathExprTree::clear().
MathExprNode::~MathExprNode() = default;

/**
 * Builds the tree from a postfix token sequence using a stack of subtrees.
 *
 * Operands become leaves. A unary operator pops one subtree as its left
 * child; a binary operator pops the right child, then the left. MATH_SPACE
 * tokens are skipped.
 *
 * On a well-formed sequence exactly one subtree remains and becomes root.
 * If the sequence is empty or malformed, every partially built node is
 * freed and root stays null, so valid() returns false and eval() returns
 * nullptr. Malformed means an operator without enough operands, a factory
 * returning null, or leftover operands.
 *
 * @param postfix tokens in postfix order (not owned)
 * @param size    number of entries in postfix
 */
MathExprTree::MathExprTree(lex_data **postfix, int size) {
    // Local work stack, so every construction starts from a clean state.
    std::stack<MathExprNode *> s;
    bool malformed = false;

    this->root = nullptr;
    for (int i = 0; i < size; i++) {
        lex_data *l = postfix[i];
        if (l->token_code == MATH_SPACE) continue;
        assert(is_operator(l->token_code) || is_operand(l->token_code));
        if (is_operand(l->token_code)) {
            DType *o = DType::factory(l->token_code, l->text, get_val);
            if (!o) {
                malformed = true;
                break;
            }
            s.push(o);
        } else if (is_unary_op(l->token_code)) {
            // Check the operand count before allocating, so nothing leaks on error.
            if (s.empty()) {
                malformed = true;
                break;
            }
            MathExprNode *node = Operator::factory(l->token_code);
            if (!node) {
                malformed = true;
                break;
            }
            MathExprNode *left = s.top(); s.pop();
            left->parent = node;
            node->left = left;
            s.push(node);
        } else {
            if (s.size() < 2) {
                malformed = true;
                break;
            }
            MathExprNode *node = Operator::factory(l->token_code);
            if (!node) {
                malformed = true;
                break;
            }
            // The right operand was pushed last, so it is popped first.
            MathExprNode *right = s.top(); s.pop();
            MathExprNode *left = s.top(); s.pop();
            node->left = left;
            node->right = right;
            left->parent = node;
            right->parent = node;
            s.push(node);
        }
    }

    if (!malformed && s.size() == 1) {
        this->root = s.top(); s.pop();
        assert(!this->root->parent);
        return;
    }

    // Empty or malformed input: release every subtree still on the stack.
    while (!s.empty()) {
        clear(s.top());
        s.pop();
    }
}

// Frees the subtree rooted at node in post-order (children before parent).
// A null node is an empty subtree and is ignored.
void MathExprTree::clear(MathExprNode *node) {
    if (node != nullptr) {
        MathExprNode *lhs = node->left;
        MathExprNode *rhs = node->right;
        clear(lhs);
        clear(rhs);
        delete node;
    }
}

// Releases every node owned by the tree and leaves root null.
MathExprTree::~MathExprTree() {
    MathExprNode *top = this->root;
    this->root = nullptr;
    clear(top);
}

void MathExprTree::inorderPrint() {
    _inorderPrint(this->root);
}

bool MathExprTree::valid() {
    return this->root && this->root->resultType() != MATH_INVALID;
}


// Evaluates the whole tree. Returns nullptr for an empty tree; otherwise
// returns whatever root->eval() yields (a new value the caller deletes, or nullptr).
DType *MathExprTree::eval() {
    return this->root ? this->root->eval() : nullptr;
}

/**
 * Debug printer: walks the subtree rooted at node in-order. For each node it
 * prints the operator's symbol or the numeric operand's value, followed by a
 * space.
 *
 * Operand types other than int/double are printed as a "<type N>"
 * placeholder (N is the numeric type id) instead of aborting, so the debug
 * output stays usable. NOTE(review): this possibly covers the DTypeVar nodes
 * that DType::factory creates for identifiers; confirm their id.
 */
static void _inorderPrint(MathExprNode *node) {
    if (!node) return;
    _inorderPrint(node->left);
    if (Operator *o = dynamic_cast<Operator *>(node)) {
        printf("%s ", o->symbol.c_str());
    } else {
        DType *t = dynamic_cast<DType *>(node);
        assert(t);  // every node is expected to be an Operator or a DType
        if (t) {
            switch (t->id) {
            case MATH_INTEGER_VALUE: printf("%d ", static_cast<DTypeInt *>(t)->val); break;
            case MATH_DOUBLE_VALUE: printf("%f ", static_cast<DTypeDouble *>(t)->val); break;
            default: printf("<type %d> ", static_cast<int>(t->id)); break;
            }
        }
    }
    _inorderPrint(node->right);
}

Enter fullscreen mode Exit fullscreen mode

Operator

// operator.h
class Operator: public MathExprNode {
protected:
    Operator();
public:
    virtual ~Operator();
    math_exp_type id;
    std::string symbol;
    bool is_unary;
    static Operator *factory(math_exp_type id);
    virtual math_exp_type resultType() = 0;
    virtual DType *eval() = 0;
};

class OperatorPlus: public Operator {
private:
    math_exp_type resultType(math_exp_type a, math_exp_type b); 
public:
    OperatorPlus();
    ~OperatorPlus();
    virtual math_exp_type resultType() override;
    virtual DType *eval() override;
};

class OperatorMinus: public Operator {
private:
    math_exp_type resultType(math_exp_type a, math_exp_type b); 
public:
    OperatorMinus();
    ~OperatorMinus();
    virtual math_exp_type resultType() override;
    virtual DType *eval() override;
};

// operator.cpp
// Base state for an operator node: no token, empty symbol, binary by default.
Operator::Operator() : id(MATH_INVALID), symbol(), is_unary(false) {}
// No resources of its own; children are freed by the owning tree.
Operator::~Operator() = default;

/**
 * Allocates the concrete operator node for a token code.
 *
 * The returned node is heap-allocated and owned by the caller. For an
 * unsupported code, debug builds assert; release builds return nullptr.
 * Without that return, an NDEBUG build would fall off the end of a
 * non-void function, which is undefined behavior.
 */
Operator *Operator::factory(math_exp_type id) {
    switch (id) {
    case MATH_PLUS: return new OperatorPlus();
    case MATH_MINUS: return new OperatorMinus();
    default: break;
    }
    assert(0 && "unsupported operator");
    return nullptr;
}

// Configures this node as the binary "+" operator.
OperatorPlus::OperatorPlus() {
    is_unary = false;
    symbol = "+";
    id = MATH_PLUS;
}
// Nothing beyond the base-class cleanup.
OperatorPlus::~OperatorPlus() = default;

// Type rule for a + b. Each supported pairing of operand types is listed
// once; any other combination yields MATH_INVALID.
// NOTE(review): whether e.g. (double, int) matches depends on comb() being
// order-insensitive; confirm against its definition.
math_exp_type OperatorPlus::resultType(math_exp_type a, math_exp_type b) {
    const auto key = comb(a, b);
    if (key == comb(MATH_INTEGER_VALUE, MATH_INTEGER_VALUE)) return MATH_INTEGER_VALUE;
    if (key == comb(MATH_DOUBLE_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_WILDCARD_VALUE, MATH_WILDCARD_VALUE)) return MATH_WILDCARD_VALUE;
    if (key == comb(MATH_INTEGER_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_WILDCARD_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_INTEGER_VALUE, MATH_WILDCARD_VALUE)) return MATH_INTEGER_VALUE;
    return MATH_INVALID;
}

// Result type of this addition, derived from both children's types.
// A missing child makes the expression invalid.
math_exp_type OperatorPlus::resultType() {
    if (this->left && this->right) {
        const math_exp_type lhs = this->left->resultType();
        const math_exp_type rhs = this->right->resultType();
        return resultType(lhs, rhs);
    }
    return MATH_INVALID;
}

/**
 * Evaluates left + right.
 *
 * Both children are evaluated first (post-order). The temporaries they
 * return are owned here and deleted before returning.
 * - int + int yields a DTypeInt.
 * - Any int/double mix, or double + double, yields a DTypeDouble (the int
 *   side is promoted).
 *
 * Returns nullptr when a child is missing, evaluates to nullptr, or has a
 * non-numeric type. Otherwise returns a new value the caller deletes.
 * NOTE(review): int + int can overflow, which is undefined behavior for signed int.
 */
DType *OperatorPlus::eval() {
    if (!this->left || !this->right) {
        return nullptr;
    }
    DType *left = this->left->eval();
    DType *right = this->right->eval();

    auto is_numeric = [](const DType *d) {
        return d && (d->id == MATH_INTEGER_VALUE || d->id == MATH_DOUBLE_VALUE);
    };
    // Only called on values already checked by is_numeric.
    auto as_double = [](const DType *d) {
        return d->id == MATH_INTEGER_VALUE
            ? static_cast<double>(static_cast<const DTypeInt *>(d)->val)
            : static_cast<const DTypeDouble *>(d)->val;
    };

    DType *res = nullptr;
    if (is_numeric(left) && is_numeric(right)) {
        if (left->id == MATH_INTEGER_VALUE && right->id == MATH_INTEGER_VALUE) {
            res = new DTypeInt(static_cast<DTypeInt *>(left)->val +
                               static_cast<DTypeInt *>(right)->val);
        } else {
            res = new DTypeDouble(as_double(left) + as_double(right));
        }
    }

    delete left;
    delete right;
    return res;
}

// Configures this node as the binary "-" operator.
OperatorMinus::OperatorMinus() {
    is_unary = false;
    symbol = "-";
    id = MATH_MINUS;
}
// Nothing beyond the base-class cleanup.
OperatorMinus::~OperatorMinus() = default;

// Type rule for a - b. Each supported pairing of operand types is listed
// once; any other combination yields MATH_INVALID.
// NOTE(review): whether e.g. (double, int) matches depends on comb() being
// order-insensitive; confirm against its definition.
math_exp_type OperatorMinus::resultType(math_exp_type a, math_exp_type b) {
    const auto key = comb(a, b);
    if (key == comb(MATH_INTEGER_VALUE, MATH_INTEGER_VALUE)) return MATH_INTEGER_VALUE;
    if (key == comb(MATH_DOUBLE_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_WILDCARD_VALUE, MATH_WILDCARD_VALUE)) return MATH_WILDCARD_VALUE;
    if (key == comb(MATH_INTEGER_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_WILDCARD_VALUE, MATH_DOUBLE_VALUE)) return MATH_DOUBLE_VALUE;
    if (key == comb(MATH_INTEGER_VALUE, MATH_WILDCARD_VALUE)) return MATH_INTEGER_VALUE;
    return MATH_INVALID;
}

// Result type of this subtraction, derived from both children's types.
// A missing child makes the expression invalid.
math_exp_type OperatorMinus::resultType() {
    if (this->left && this->right) {
        const math_exp_type lhs = this->left->resultType();
        const math_exp_type rhs = this->right->resultType();
        return resultType(lhs, rhs);
    }
    return MATH_INVALID;
}

/**
 * Evaluates left - right.
 *
 * Both children are evaluated first (post-order). The temporaries they
 * return are owned here and deleted before returning.
 * - int - int yields a DTypeInt.
 * - Any int/double mix, or double - double, yields a DTypeDouble (the int
 *   side is promoted).
 *
 * Returns nullptr when a child is missing, evaluates to nullptr, or has a
 * non-numeric type. Otherwise returns a new value the caller deletes.
 * NOTE(review): int - int can overflow, which is undefined behavior for signed int.
 */
DType *OperatorMinus::eval() {
    if (!this->left || !this->right) {
        return nullptr;
    }
    DType *left = this->left->eval();
    DType *right = this->right->eval();

    auto is_numeric = [](const DType *d) {
        return d && (d->id == MATH_INTEGER_VALUE || d->id == MATH_DOUBLE_VALUE);
    };
    // Only called on values already checked by is_numeric.
    auto as_double = [](const DType *d) {
        return d->id == MATH_INTEGER_VALUE
            ? static_cast<double>(static_cast<const DTypeInt *>(d)->val)
            : static_cast<const DTypeDouble *>(d)->val;
    };

    DType *res = nullptr;
    if (is_numeric(left) && is_numeric(right)) {
        if (left->id == MATH_INTEGER_VALUE && right->id == MATH_INTEGER_VALUE) {
            res = new DTypeInt(static_cast<DTypeInt *>(left)->val -
                               static_cast<DTypeInt *>(right)->val);
        } else {
            res = new DTypeDouble(as_double(left) - as_double(right));
        }
    }

    delete left;
    delete right;
    return res;
}
Enter fullscreen mode Exit fullscreen mode

Operand

// dtype.h
class DType: public MathExprNode {
protected:
    DType();
public:
    virtual ~DType();
    math_exp_type id;
    virtual void setValue(void *v) = 0;
    virtual void setValue(DType *v) = 0;
    virtual math_exp_type resultType() = 0;
    virtual DType *eval() = 0;
    virtual DType *clone() = 0;
    static DType *factory(math_exp_type id, const char *v, DType *(*get_val)(const char *name)=nullptr);
};

class DTypeInt: public DType {
public:
    int val;
    DTypeInt(int v=0);
    DTypeInt(const char *v);
    virtual void setValue(void *v) override;
    virtual void setValue(DType *v) override;
    virtual math_exp_type resultType() override;
    virtual DType *eval() override;
    virtual DType *clone() override;
};

class DTypeDouble: public DType {
public:
    double val;
    DTypeDouble(double v=0.0);
    DTypeDouble(const char *v);
    virtual void setValue(void *v) override;
    virtual void setValue(DType *v) override;
    virtual math_exp_type resultType() override;
    virtual DType *eval() override;
    virtual DType *clone() override;
};

// dtype.cpp
// The base has no concrete type until a subclass sets id.
DType::DType() : id(MATH_INVALID) {}
// Nothing to release in the base operand.
DType::~DType() = default;

/**
 * Creates the operand node for a token.
 *
 * MATH_INTEGER_VALUE and MATH_DOUBLE_VALUE parse v as a number.
 * MATH_IDENTIFIER creates a DTypeVar from v and get_val (presumably used to
 * look up the variable's value later; see DTypeVar).
 *
 * Returns a heap-allocated node owned by the caller. For any other token
 * code, debug builds assert and release builds return nullptr instead of
 * falling off the end of the function.
 */
DType *DType::factory(math_exp_type id, const char *v, DType *(*get_val)(const char *name)) {
    switch (id) {
    case MATH_INTEGER_VALUE: return new DTypeInt(v);
    case MATH_DOUBLE_VALUE: return new DTypeDouble(v);
    case MATH_IDENTIFIER: return new DTypeVar(v, get_val);
    default: break;
    }
    assert(0 && "unsupported operand type");
    return nullptr;
}

// Integer operand holding v.
DTypeInt::DTypeInt(int v) : val(v) {
    this->id = MATH_INTEGER_VALUE;
}

// Integer operand parsed from text; the delegated constructor sets the type first.
DTypeInt::DTypeInt(const char *v): DTypeInt() {
    this->setValue(static_cast<void *>(const_cast<char *>(v)));
}

// Interprets v as a NUL-terminated decimal string and stores atoi's result.
void DTypeInt::setValue(void *v) {
    const char *text = static_cast<const char *>(v);
    this->val = atoi(text);
}

// Copies the value of another integer operand.
// If v is null or not a DTypeInt, val is left unchanged and debug builds
// assert. A double is deliberately not narrowed here.
// This avoids dereferencing a null dynamic_cast result.
void DTypeInt::setValue(DType *v) {
    DTypeInt *src = dynamic_cast<DTypeInt *>(v);
    assert(src && "DTypeInt::setValue expects a DTypeInt");
    if (src) {
        this->val = src->val;
    }
}

// An integer leaf always produces an integer.
math_exp_type DTypeInt::resultType() { return MATH_INTEGER_VALUE; }

// A literal evaluates to an independent copy of itself, which the caller deletes.
DType *DTypeInt::eval() { return this->clone(); }

// Heap copy holding the same value; the caller owns it.
DType *DTypeInt::clone() {
    DTypeInt *copy = new DTypeInt(this->val);
    return copy;
}

// Floating-point operand holding v.
DTypeDouble::DTypeDouble(double v) : val(v) {
    this->id = MATH_DOUBLE_VALUE;
}

// Floating-point operand parsed from text; the delegated constructor sets the type first.
DTypeDouble::DTypeDouble(const char* v): DTypeDouble() {
    this->setValue(static_cast<void *>(const_cast<char *>(v)));
}

// Interprets v as a NUL-terminated decimal string and stores atof's result.
void DTypeDouble::setValue(void *v) {
    const char *text = static_cast<const char *>(v);
    this->val = atof(text);
}

// Copies the value of another numeric operand.
// A DTypeInt source is widened to double, matching the int -> double
// promotion the arithmetic operators use. If v is null or of another type,
// val is left unchanged and debug builds assert.
// This avoids dereferencing a null dynamic_cast result.
void DTypeDouble::setValue(DType *v) {
    if (DTypeDouble *d = dynamic_cast<DTypeDouble *>(v)) {
        this->val = d->val;
    } else if (DTypeInt *i = dynamic_cast<DTypeInt *>(v)) {
        this->val = static_cast<double>(i->val);
    } else {
        assert(0 && "DTypeDouble::setValue expects a numeric DType");
    }
}

// A floating-point leaf always produces a double.
math_exp_type DTypeDouble::resultType() { return MATH_DOUBLE_VALUE; }

// A literal evaluates to an independent copy of itself, which the caller deletes.
DType *DTypeDouble::eval() { return this->clone(); }

// Heap copy holding the same value; the caller owns it.
DType *DTypeDouble::clone() {
    DTypeDouble *copy = new DTypeDouble(this->val);
    return copy;
}
Enter fullscreen mode Exit fullscreen mode

Testing the Tree implementation

In infix_to_postfix_test.c, you can add the following code to test the tree. (The test uses C++ features such as classes and new, so the file must be compiled as C++.)

// infix_to_postfix_test.c
/*
 * DO_TEST(infix_array, test_tree): the part elided below ("...") prepares
 * postfix/len from infix_array. When test_tree is non-zero, the macro then:
 *   - builds a MathExprTree from the postfix tokens,
 *   - prints the tree in-order and reports whether it is valid,
 *   - if valid, prints the evaluated result.
 * eval() may return NULL (the operators return nullptr for operand types
 * they do not handle), so the result is checked before it is dereferenced.
 */
#define DO_TEST(infix_array, test_tree) { \
    ...
    if (test_tree) { \
        MathExprTree *tree = new MathExprTree(postfix, len); \
        tree->inorderPrint(); \
        printf("\n"); \
        bool ok = tree->valid(); \
        printf("tree %s\n", ok ? "valid" : "in-valid"); \
        if (ok) { \
            printf("res: "); \
            DType *res = tree->eval(); \
            if (!res) { \
                printf("(null)"); \
            } else { \
                switch (res->id) { \
                case MATH_INTEGER_VALUE: printf("%d", ((DTypeInt *) res)->val); break; \
                case MATH_DOUBLE_VALUE: printf("%f", ((DTypeDouble *) res)->val); break; \
                case MATH_BOOL: printf("%s", ((DTypeBool *) res)->val ? "True" : "False"); break; \
                default: assert(0); \
                } \
            } \
            printf("\n"); \
            delete res; \
        } \
        delete tree; \
    } \
}

// Test driver. Each lex_data array below is the token stream for the infix
// expression in the comment above it. Each entry appears to be
// {token code, text length, text}. DO_TEST runs the (partially elided)
// conversion and, when its second argument is non-zero, also builds and
// evaluates the expression tree.
// NOTE(review): Operator::factory in this article only implements MATH_PLUS
// and MATH_MINUS. These cases use MATH_MUL / MATH_DIV, so the remaining
// operators from the exercise are needed before the tree tests can pass.
int main(void) {

    ...

    // 10 + 12 * 42.3 / (2 - 0.5)
    lex_data infix_array5[] = {
        {MATH_INTEGER_VALUE, 2, "10" },
        {MATH_SPACE, 1, " " },
        {MATH_PLUS, 1, "+" },
        {MATH_SPACE, 1, " " },
        {MATH_INTEGER_VALUE, 2, "12" },
        {MATH_SPACE, 1, " " },
        {MATH_MUL, 1, "*" },
        {MATH_SPACE, 1, " " },
        {MATH_DOUBLE_VALUE, 4, "42.3" },
        {MATH_DIV, 1, "/" },
        {MATH_BRACKET_START, 1, "(" },
        {MATH_INTEGER_VALUE, 1, "2" },
        {MATH_MINUS, 1, "-" },
        {MATH_DOUBLE_VALUE, 3, "0.5" },
        {MATH_BRACKET_END, 1, ")" },
    };

    // 1 + 2 * 3 
    lex_data infix_array6[] = {
        {MATH_INTEGER_VALUE, 1, "1" },
        {MATH_PLUS, 1, "+" },
        {MATH_INTEGER_VALUE, 1, "2" },
        {MATH_MUL, 1, "*" },
        {MATH_INTEGER_VALUE, 1, "3" },
    };

    // 1 + 2 * 3.0 
    lex_data infix_array7[] = {
        {MATH_INTEGER_VALUE, 1, "1" },
        {MATH_PLUS, 1, "+" },
        {MATH_INTEGER_VALUE, 1, "2" },
        {MATH_MUL, 1, "*" },
        {MATH_DOUBLE_VALUE, 3, "3.0" },
    };

    ...
     // Second argument 1: also build, print and evaluate the tree.
     DO_TEST(infix_array5, 1);
     DO_TEST(infix_array6, 1);
     DO_TEST(infix_array7, 1);
}

Enter fullscreen mode Exit fullscreen mode

The remaining operators (*, /, max, min, sqr, sqrt, pow, =, !=, <, <=, >, >=) are left as an exercise.

Top comments (1)

Collapse
 
christopher2093 profile image
Christopher

Great insight! Thanks