DEV Community

Drj

Maximum Product of Splitted Binary Tree

Problem statement

Given the root of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.

Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo 10^9 + 7.

Note that you need to maximize the answer before taking the mod and not after taking it.
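Here is a tiny illustration of why the order matters. It uses made-up product values rather than real tree sums. A product just above the modulus wraps around to a small number, so comparing already-reduced values can pick the wrong split. The function names are hypothetical.

```python
MOD = 10**9 + 7


def best_product(products):
    """Maximize first, then reduce modulo MOD (the order the problem requires)."""
    return max(products) % MOD


def wrong_best_product(products):
    """Reducing each product first can change which one looks largest."""
    return max(p % MOD for p in products)


# Made-up products: the first is larger, but it wraps around to 1 after the mod.
products = [MOD + 1, 5]
print(best_product(products))        # 1
print(wrong_best_product(products))  # 5
```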

Link to the question on LeetCode

Example 1:

Input: root = [1,2,3,4,5,6]
Output: 110
Explanation: Remove the red edge and get 2 binary trees with sum 11 and 10. Their product is 110 (11*10)

[Image: Example 1]
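To check Example 1 by hand, here is a small brute-force sketch. It makes three assumptions:

- The `TreeNode` class is shaped like the one LeetCode provides.
- `build_tree` is a hypothetical helper that turns LeetCode's level-order list (with `None` for missing children) into nodes.
- `brute_force_max_product` is also a name chosen here.

It recomputes both sums for every cut, so it runs in O(n²) time. It is only meant as a cross-check, not as the solution.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Same shape as the TreeNode definition LeetCode provides.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def tree_sum(node: Optional[TreeNode]) -> int:
    """Sum of every value in the subtree rooted at node."""
    return 0 if node is None else node.val + tree_sum(node.left) + tree_sum(node.right)


def brute_force_max_product(root: TreeNode) -> int:
    """Cut every parent-child edge in turn and recompute both sums from scratch."""
    total = tree_sum(root)
    best = 0
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                part = tree_sum(child)  # sum of the subtree that gets detached
                best = max(best, part * (total - part))
                stack.append(child)
    return best % (10**9 + 7)


print(brute_force_max_product(build_tree([1, 2, 3, 4, 5, 6])))  # 110
```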

Approach

The task is to split the binary tree into two subtrees so that the product of their sums is as large as possible. Splitting the tree means removing exactly one edge.

We can solve this efficiently by precomputing the sum of the subtree rooted at each node. With these sums, we get the totals of both parts for any cut without traversing the tree again.

The idea is then to find the maximum, over all nodes, of (subtree sum) × (total tree sum − subtree sum). Removing the edge above a node splits the tree into exactly those two parts.
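The same idea can be written as a formula. Here S(v) denotes the sum of the subtree rooted at node v and T = S(root) is the sum of the whole tree. This notation is introduced only for illustration.

```latex
\text{answer} = \Bigl( \max_{v \neq \text{root}} \; S(v)\,\bigl(T - S(v)\bigr) \Bigr) \bmod \bigl(10^9 + 7\bigr)
```

Every non-root node v has exactly one edge to its parent, so cutting that edge is the same as picking v. This is why looping over all nodes covers every edge. Including the root is harmless, because S(root)·(T − S(root)) = 0.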

Let's work through an example with the tree [2,3,9,10,7,8,6,5,4,11,1], given in level order as on LeetCode. Its tree representation is shown below.

[Image: Input tree]

First, let's calculate the subtree sum for each node, that is, the sum of all values in the subtree rooted at that node. The sums for every subtree are shown below:

[Image: All possible subtrees]
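The sums in the image can be reproduced with a short post-order sketch. To stay self-contained, it skips building node objects and works directly on the level-order list. This only works because this example tree is complete (no missing entries), so the children of index i sit at 2*i + 1 and 2*i + 2. That index shortcut is an assumption of this sketch, not part of the original solution.

```python
from typing import Dict, List


def subtree_sums(values: List[int]) -> Dict[int, int]:
    """Post-order subtree sums for a tree stored as a complete level-order list.

    Assumes no missing (None) entries, so the children of index i are at
    2*i + 1 and 2*i + 2. Returns {index: sum of the subtree rooted there}.
    """
    sums: Dict[int, int] = {}

    def visit(i: int) -> int:
        if i >= len(values):
            return 0
        # Children first (post-order), then the node itself.
        total = visit(2 * i + 1) + visit(2 * i + 2) + values[i]
        sums[i] = total
        return total

    visit(0)
    return sums


values = [2, 3, 9, 10, 7, 8, 6, 5, 4, 11, 1]
for i, s in sorted(subtree_sums(values).items()):
    print(f"node {values[i]}: subtree sum {s}")
```

For this tree, the sketch reports:

- 66 for the root 2
- 41 for node 3
- 23 for node 9
- 19 for both 10 and 7
- each leaf's own value for the leaves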

Now let's try removing the edge between the root node 2 and its left child 3. This gives two subtrees, and we can read both sums straight from our precomputed values, as shown in the image below.

[Image: Subtree sum]

The sum of subtree 1 (rooted at 3) comes directly from our precomputed values. The sum of the other subtree is the total tree sum minus the sum of subtree 1. Multiply these two sums to get the product for this cut.
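For the example tree, the numbers for this cut can be worked out from the input list. This is my own arithmetic, using S(v) for the sum of the subtree rooted at v and T for the total.

```latex
\begin{aligned}
S(10) &= 10 + 5 + 4 = 19, \qquad S(7) = 7 + 11 + 1 = 19, \qquad S(9) = 9 + 8 + 6 = 23 \\
S(3)  &= 3 + S(10) + S(7) = 3 + 19 + 19 = 41 \\
T     &= 2 + S(3) + S(9) = 2 + 41 + 23 = 66 \\
S(3)\,\bigl(T - S(3)\bigr) &= 41 \times (66 - 41) = 41 \times 25 = 1025
\end{aligned}
```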

In the same way, we try removing each edge one by one and keep the maximum product of the two subtree sums.
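Once the sums are known, trying every edge is a single pass over them. In this sketch, the subtree sums of the example tree are written out by hand from the input list. They are keyed by node value, which only works because every value in this tree is unique. The `best_split` name is hypothetical.

```python
from typing import Dict

MOD = 10**9 + 7


def best_split(sums: Dict[int, int], root: int) -> int:
    """Given {node: subtree sum}, try cutting the edge above every non-root node."""
    total = sums[root]
    best = 0
    for node, s in sums.items():
        if node == root:
            continue  # the root has no parent edge to cut
        best = max(best, s * (total - s))
    # Maximize first, then reduce.
    return best % MOD


# Subtree sums of the example tree, keyed by node value.
example_sums = {2: 66, 3: 41, 9: 23, 10: 19, 7: 19, 8: 8, 6: 6, 5: 5, 4: 4, 11: 11, 1: 1}
print(best_split(example_sums, root=2))  # 1025
```

The runner-up for this tree is the cut above node 9, which gives 23 × 43 = 989.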

Note: to precompute all the subtree sums, use a post-order traversal of the binary tree (children before parent) and store each node's sum in a dictionary keyed by the node.
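The recursive helper in the code can hit Python's default recursion limit (1000 frames) on a very deep, skewed tree. If that is a concern, the same post-order traversal can be done with an explicit stack. This is a sketch of that alternative, not the approach used in this post. It assumes a `TreeNode` shaped like LeetCode's, and `subtree_sums_iterative` is a hypothetical name.

```python
from typing import Dict, Optional


class TreeNode:
    # Same shape as the TreeNode definition LeetCode provides.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sums_iterative(root: Optional[TreeNode]) -> Dict[TreeNode, int]:
    """Post-order subtree sums using an explicit stack instead of recursion."""
    sums: Dict[TreeNode, int] = {}
    stack = [(root, False)] if root else []
    while stack:
        node, children_done = stack.pop()
        if children_done:
            # Both children were finished earlier, so their sums are available
            # (a missing child looks up None and falls back to 0).
            sums[node] = node.val + sums.get(node.left, 0) + sums.get(node.right, 0)
        else:
            # Revisit this node after its children have been processed.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
    return sums


# Example 1 tree: 1 -> (2 -> 4, 5), (3 -> 6)
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
sums = subtree_sums_iterative(root)
print(sums[root], sums[root.left], sums[root.right])  # 21 11 9
```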

Code

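from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0
        # Map every node to the sum of the subtree rooted at it.
        subtree_sum = dict()
        self.subtree_sum(root, subtree_sum)
        root_subtree_sum = subtree_sum[root]
        maximum = 0
        # Cutting the edge above a node leaves one part with that node's
        # subtree sum and the other with (total - subtree sum).
        # The root itself contributes 0, so it never changes the maximum.
        for subtree in subtree_sum:
            maximum = max(maximum, subtree_sum[subtree] * (root_subtree_sum - subtree_sum[subtree]))
        # Maximize first, then take the modulo. Keep it in exact integer
        # arithmetic: products can exceed 2**53, where floats lose precision.
        return maximum % (10**9 + 7)

    def subtree_sum(self, root, subtree_sum):
        # Post-order traversal: both children are summed before their parent.
        if root is None:
            return
        self.subtree_sum(root.left, subtree_sum)
        self.subtree_sum(root.right, subtree_sum)
        left_sum = 0
        right_sum = 0
        if root.left:
            left_sum = subtree_sum[root.left]
        if root.right:
            right_sum = subtree_sum[root.right]
        subtree_sum[root] = left_sum + right_sum + root.val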