Leetcode 308 Solution

This article provides a solution to LeetCode question 308 (range-sum-query-2d-mutable).

https://leetcode.com/problems/range-sum-query-2d-mutable

Solution

// Array-backed segment tree over a sequence of ints, supporting point
// assignment and inclusive range-sum queries.
//
// The root is stored at index 0 and the children of node i are stored at
// 2*i+1 and 2*i+2. The tree does not remember which range a node covers:
// every operation receives the node index together with the inclusive range
// [left, right] that node represents, so callers start at (0, 0, n - 1).
class SegmentTree
{
    int m_size = 0;          // length passed to the last init() call
    std::vector<int> nodes;  // per-node range sums, heap-style layout
public:
    // Allocates zeroed storage for a tree over n elements.
    // The leaf count is rounded up to a power of two with integer arithmetic
    // (no floating-point pow/log2), and twice that many slots are reserved so
    // that every index reachable through 2*i+1 / 2*i+2 stays in bounds.
    // A non-positive n leaves the tree empty.
    void init(int n)
    {
        m_size = n;
        if (n <= 0)
        {
            nodes.clear();
            return;
        }

        std::size_t leaves = 1;
        while (leaves < static_cast<std::size_t>(n))
            leaves *= 2;
        nodes.assign(2 * leaves, 0);
    }

    // Returns the length given to init() (0 if init() was never called).
    int getsize() const
    {
        return m_size;
    }

    // Fills the subtree rooted at node i, which covers nums[l..r] inclusive,
    // and returns the sum stored at that node. An empty range (l > r)
    // contributes 0 and touches no node.
    int build(const std::vector<int>& nums, int i, int l, int r)
    {
        if (l > r)
            return 0;
        if (l == r)
            return nodes[i] = nums[l];

        const int mid = l + (r - l) / 2;
        const int left_sum = build(nums, 2 * i + 1, l, mid);
        const int right_sum = build(nums, 2 * i + 2, mid + 1, r);
        return nodes[i] = left_sum + right_sum;
    }

    // Sets the element at position pos to val within the subtree rooted at
    // node i covering [left, right]. Returns the change (val minus the old
    // value), which is also added to every node on the path to the leaf.
    // A pos outside [left, right] changes nothing and returns 0.
    int update(int i, int left, int right, int pos, int val)
    {
        if (pos < left || pos > right)
            return 0;

        if (left == right)
        {
            const int diff = val - nodes[i];
            nodes[i] = val;
            return diff;
        }

        const int mid = left + (right - left) / 2;
        const int diff = pos <= mid
            ? update(2 * i + 1, left, mid, pos, val)
            : update(2 * i + 2, mid + 1, right, pos, val);
        nodes[i] += diff;
        return diff;
    }

    // Returns the sum of the elements in [l, r] inclusive, evaluated in the
    // subtree rooted at node i covering [left, right]. The requested range is
    // clamped to the node's range first, so a request that is empty or lies
    // outside the node yields 0 instead of descending past a leaf.
    int query(int i, int left, int right, int l, int r) const
    {
        l = std::max(l, left);
        r = std::min(r, right);
        if (l > r)
            return 0;
        if (l == left && r == right)
            return nodes[i];

        const int mid = left + (right - left) / 2;

        // Each child clamps the request to its own half.
        int total = 0;
        if (l <= mid)
            total += query(2 * i + 1, left, mid, l, r);
        if (r > mid)
            total += query(2 * i + 2, mid + 1, right, l, r);
        return total;
    }
};

class NumMatrix {
    vector<SegmentTree> segment_trees;
    int m;
    int n;
public:
    NumMatrix(vector<vector<int>> matrix) {
        if (matrix.size() == 0 || matrix[0].size() == 0)
            return;

        m = matrix.size();
        n = matrix[0].size();

        for (int i = 0; i < m; i++)
        {
            SegmentTree segment_tree;
            segment_tree.init(n);
            segment_tree.build(matrix[i], 0, 0, n - 1);
            segment_trees.push_back(segment_tree);
        }
    }

    void update(int row, int col, int val) {
        if (m == 0 || n == 0)
            return;

        segment_trees[row].update(0, 0, n - 1, col, val);
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        if (m == 0 || n == 0)
            return 0;

        int sum = 0;
        for (int i = row1; i <= row2; i++)
            sum += segment_trees[i].query(0, 0, n - 1, col1, col2);
        return sum;
    }
};

/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix* obj = new NumMatrix(matrix);
 * obj->update(row,col,val);
 * int param_2 = obj->sumRegion(row1,col1,row2,col2);
 */