Segment Tree Implementation in Java

A segment tree is a binary tree that supports operations such as finding the minimum, maximum or sum over a given range of indices in an array. The tree is stored in a flat array, in the same way a binary heap is stored over an array.

Depending on the requirements, the range minimum query code below can be adapted to answer range maximum or range sum queries.

The API of the class looks like this:

public class SegmentTree {
    public SegmentTree(int[] A);
    public int query(Point interval);
    public void update(int index, int value);
}

Complete Code:


import java.awt.*;

/**
 * Created by dhruv.pancholi on 17/01/16.
 * @author Dhruv Pancholi
 * @version 1.0
 */
public class SegmentTree {

    private int[] oArray;
    private int[] stArray;

    /**
     * Class which builds segment tree for the input array
     *
     * @param oArray original array
     */
    public SegmentTree(int[] oArray) {
        this.oArray = oArray;

        // The required length of the segment tree array
        // same as heap array
        int len = (int) Math.pow(2, Math.ceil(log2(oArray.length) + 1));
        stArray = new int[len];

        buildSegmentTree(0, new Point(0, oArray.length - 1));
    }

    /**
     * Builds a segment tree object for Range minimum query
     * Runtime: O(N*logN)
     *
     * @param node     index of the stArray which represents an interval
     * @param interval the interval which is represented by the node
     */
    private void buildSegmentTree(int node, Point interval) {
        if (interval.x == interval.y) {
            stArray[node] = interval.x;
            return;
        }
        Point children = new Point(2 * node + 1, 2 * node + 2);
        buildSegmentTree(children.x, new Point(interval.x, (interval.x + interval.y) / 2));
        buildSegmentTree(children.y, new Point((interval.x + interval.y) / 2 + 1, interval.y));
        stArray[node] = (oArray[stArray[children.x]] <= oArray[stArray[children.y]]) ? stArray[children.x] : stArray[children.y];
    }

    /**
     * Update the value of the array and the corresponding tree
     *
     * @param node      Traversal node index
     * @param cInterval Interval in which the index lies
     * @param index     The index value to be updated
     * @param value     The value with which to be updated
     */
    private void update(int node, Point cInterval, int index, int value) {
        if (cInterval.x == cInterval.y) {
            oArray[index] = value;
            return;
        }

        if (index <= (cInterval.x + cInterval.y) / 2) {
            update(2 * node + 1, new Point(cInterval.x, (cInterval.x + cInterval.y) / 2), index, value);
        } else {
            update(2 * node + 2, new Point((cInterval.x + cInterval.y) / 2 + 1, cInterval.y), index, value);
        }
        stArray[node] = (oArray[stArray[2 * node + 1]] <= oArray[stArray[2 * node + 2]]) ? stArray[2 * node + 1] : stArray[2 * node + 2];
    }

    /**
     * Wrapper method for update
     * Update API
     * Runtime: O(lgN)
     * @param index Index to be updated
     * @param value The value to be updated with
     */
    public void update(int index, int value) {
        update(0, new Point(0, oArray.length - 1), index, value);
    }

    /**
     * @param node Parent of the given sub-segment tree
     * @param cInterval Current interval
     * @param rInterval Required interval
     * @return index of the minimum element in the given range
     */
    private int query(int node, Point cInterval, Point rInterval) {
        if (cInterval.y < rInterval.x || cInterval.x > rInterval.y) {
            return -1;
        }
        if (cInterval.x >= rInterval.x && cInterval.y <= rInterval.y) {
            return stArray[node];
        }

        Point children = new Point(2 * node + 1, 2 * node + 2);
        Point result = new Point();
        result.x = query(children.x, new Point(cInterval.x, (cInterval.x + cInterval.y) / 2), rInterval);
        result.y = query(children.y, new Point((cInterval.x + cInterval.y) / 2 + 1, cInterval.y), rInterval);

        if (result.x == -1) {
            return result.y;
        }
        if (result.y == -1) {
            return result.x;
        }

        return (result.x <= result.y) ? result.x : result.y;
    }

    /**
     * Wrapper method to query from outside
     * Runtime: O(lgN)
     *
     * @param rInterval interval in which minimum is required
     * @return range minimum index
     */
    public int query(Point rInterval) {
        return query(0, new Point(0, oArray.length - 1), rInterval);
    }

    /**
     * Default is to the base e
     *
     * @param x
     * @return log to the base 2
     */
    private double log2(int x) {
        return Math.log(x) / Math.log(2);
    }

    public static void main(String[] args) {
        int[] A = new int[]{8, 7, 3, 9, 5, 1, 10};
        SegmentTree segmentTree = new SegmentTree(A);
        segmentTree.update(6, -1);
        System.out.println(segmentTree.query(new Point(0, 6)));
    }
}