import TinyQueue from "tinyqueue";

interface AStarOptions<T, U = T> {
    /**
     * Arbitrary caller-supplied value. It is passed unchanged as the `state` argument of
     * `identity`, `traversalCost`, `heuristic` and `getNeighbours`.
     */
    state?: any;

    /**
     * Returns the value to use as the identity for the node. Identities are used as `Map` keys to track
     * per-node search state and are compared with `===` to detect the target node, so two values that
     * represent the same logical node must produce `===`-equal identities.
     * Default behaviour is to use the node itself as the identity.
     */
    identity?: (node: T, state?: any) => any;

    /**
     * Returns the cost of traversing from node `a` to its neighbour `b`.
     *
     * @param a - The node currently being expanded.
     * @param b - The neighbour of `a` being considered.
     * @param from - The start node of the search.
     * @param to - The search target: either a node or the goal predicate.
     * @param state - The value of `AStarOptions.state`.
     * @param nodeState - The `nodeState` returned by the traversal that reached `a` along its current best
     *   route; `undefined` for the start node, or when that traversal returned a plain number or no state.
     * @returns Either a plain number, used as both the actual cost and the compare cost, or an object with:
     *   - `cost`: the actual cost, accumulated into `PathResult.cost` / `PathResult.pathCost`;
     *   - `compareCost` (defaults to `cost`): the cost used to rank routes during the search and to decide
     *     whether the traversal is possible at all;
     *   - `nodeState`: a value stored on `b` when this traversal becomes `b`'s best route; it is passed back
     *     as `nodeState` when `b` is expanded and is reported in `PathResult.pathState`.
     *
     * A compare cost under 1 means the traversal is impossible and `b` is skipped from `a`. A cost of 1 is
     * the baseline; values greater than 1 indicate that the traversal is more costly than baseline.
     * Default behaviour is to return 1 for every traversal.
     */
    traversalCost?: (
        a: T,
        b: T,
        from: T,
        to: U,
        state?: any,
        nodeState?: any
    ) => number | { cost: number; compareCost?: number; nodeState?: any };

    /**
     * Estimates the remaining cost from node `a` to the search target `b` (a node, or the goal predicate),
     * in the same units as the compare cost returned by `traversalCost`. It is evaluated once per node,
     * when the node is first reached.
     * For the returned path to be the cheapest, the estimate should never overestimate the true remaining
     * cost (and should be consistent), because nodes that have been expanded are never revisited.
     * Default behaviour is to use 0 (no heuristic). This will significantly affect the performance of
     * finding paths - a heuristic function should be provided wherever possible.
     */
    heuristic?: (a: T, b: U, state?: any) => number;

    /**
     * Gets all of the neighbour nodes (linked nodes) of the specified node. It is called once for each
     * node the search expands.
     */
    getNeighbours: (node: T, state?: any) => T[];
}

/**
 * Per-node bookkeeping maintained by the A* search for every node it has encountered.
 */
interface NodeSearchState<T> {
    /** The graph node this record describes. */
    node: T;

    /** Best known route into this node: the predecessor record. Absent on the start node. */
    parent?: NodeSearchState<T>;
    /** Node state produced by the traversal into this node along its best known route, if any. */
    state?: any;

    /** Accumulated compare cost from the start node along the best known route. */
    g: number;
    /** Heuristic estimate of the remaining cost to the target; computed when the node is first reached. */
    h: number;
    /** Heap priority: `g + h`. */
    f: number;
    /** Accumulated actual cost along the best known route (may differ from `g` when a compare cost is used). */
    cost: number;

    /** Set once the node has been reached through an edge and queued; the start node keeps `false`. */
    isVisited: boolean;
    /** Set once the node has been expanded; closed nodes are skipped when seen again as neighbours. */
    isClosed: boolean;
}

/**
 * The outcome of a successful path search.
 */
export interface PathResult<T> {
    /** Every node on the path, from the start node to the final node inclusive. */
    path: T[];
    /**
     * Cumulative actual cost at each node after the start: `pathCost[i]` is the cost of reaching
     * `path[i + 1]`. Empty when the path consists of the start node alone.
     */
    pathCost: number[];
    /**
     * Node states aligned like `pathCost` (`pathState[i]` belongs to `path[i + 1]`); left undefined when
     * no step along the path recorded a node state.
     */
    pathState?: any[];
    /** Total actual cost of the path (the last entry of `pathCost`, or 0 for a single-node path). */
    cost: number;
}

/**
 * Searches for a path from `from` to the node `to` using A*.
 *
 * Routes are ranked by accumulated compare cost plus the heuristic estimate; the returned path is the
 * cheapest one in those units when the heuristic never overestimates (and is consistent).
 *
 * @param from - The start node.
 * @param to - The target node; a node is the target when its identity is `===` to the target's identity.
 * @param options - Graph accessors and cost functions.
 * @returns The path and its costs, or `undefined` if the target cannot be reached.
 */
export function findPathAStar<T extends object>(from: T, to: T, options: AStarOptions<T>): PathResult<T> | undefined;
/**
 * Searches, using A*, for a path from `from` to the first node taken off the open set for which the goal
 * predicate `to` returns true.
 *
 * @param from - The start node.
 * @param to - Goal predicate, evaluated on each node as it is taken off the open set.
 * @param options - Graph accessors and cost functions; `traversalCost` and `heuristic` receive the predicate.
 * @returns The path and its costs, or `undefined` if no node satisfying `to` can be reached.
 */
export function findPathAStar<T extends object>(
    from: T,
    to: (node: T) => boolean,
    options: AStarOptions<T, (node: T) => boolean>
): PathResult<T> | undefined;
export function findPathAStar<T extends object>(
    from: T,
    to: T | ((node: T) => boolean),
    options: AStarOptions<T> | AStarOptions<T, (node: T) => boolean>
): PathResult<T> | undefined {
    // Search records keyed by node identity.
    const nodeStateMap = new Map<any, NodeSearchState<T>>();

    const identity = options.identity ?? ((node: T, state?: any) => node);
    const traversalCost = options.traversalCost ?? ((a: T, b: T, state?: any, nodeState?: any) => 1);
    const heuristic = options.heuristic ?? ((a: T, b: T | ((to: T) => boolean), state?: any) => 0);

    const isToFunction = typeof to === "function";
    // The target's identity cannot change during the search, so compute it once instead of on every pop.
    const targetIdentity = isToFunction ? undefined : identity(to, options.state);

    const startNode = createSearchState(from);
    startNode.h = heuristic(from, to as T & ((node: T) => boolean), options.state);

    nodeStateMap.set(identity(startNode.node, options.state), startNode);
    // Open set, ordered by lowest f first.
    const openHeap = new TinyQueue([startNode], (a, b) => a.f - b.f);

    while (openHeap.length > 0) {
        // Grab the lowest f to process next.
        const currentNode = openHeap.pop() as NodeSearchState<T>;

        // Goal test happens when a node is taken off the open set, not when it is first discovered.
        if (isToFunction) {
            if (to(currentNode.node)) {
                return pathTo(currentNode);
            }
        } else if (identity(currentNode.node, options.state) === targetIdentity) {
            return pathTo(currentNode);
        }

        // Move current node from open to closed, then process each of its neighbours.
        currentNode.isClosed = true;
        const neighbours = options.getNeighbours(currentNode.node, options.state);
        for (let i = 0; i < neighbours.length; i++) {
            const neighbour = neighbours[i];
            const neighbourIdentity = identity(neighbour, options.state);
            let neighbourState = nodeStateMap.get(neighbourIdentity);
            if (!neighbourState) {
                neighbourState = createSearchState(neighbour);
                nodeStateMap.set(neighbourIdentity, neighbourState);
            }

            // Expanded nodes are never reopened.
            if (neighbourState.isClosed) {
                continue;
            }

            const cost = traversalCost(
                currentNode.node,
                neighbour,
                from,
                to as T & ((node: T) => boolean),
                options.state,
                currentNode.state
            );
            // A numeric result is both the actual and the compare cost; the object form may separate them
            // and attach a state to the neighbour.
            let actualCost: number;
            let compareCost: number;
            let newState: any;
            if (typeof cost === "object") {
                actualCost = cost.cost;
                compareCost = cost.compareCost ?? actualCost;
                newState = cost.nodeState;
            } else {
                actualCost = cost;
                compareCost = cost;
            }

            // A compare cost under 1 marks the traversal as impossible.
            if (compareCost < 1) {
                continue;
            }

            // Check to see if this is the best way to this node so far.
            const g = currentNode.g + compareCost;
            const isVisited = neighbourState.isVisited;
            if (!isVisited || g < neighbourState.g) {
                // Found the best path to this node so far, update its search state.
                neighbourState.isVisited = true;
                neighbourState.parent = currentNode;
                neighbourState.g = g;
                neighbourState.cost = currentNode.cost + actualCost;
                neighbourState.state = newState;
                if (!isVisited) {
                    // The heuristic is evaluated only once per node.
                    neighbourState.h = heuristic(neighbour, to as T & ((node: T) => boolean), options.state);
                }

                neighbourState.f = g + neighbourState.h;

                if (!isVisited) {
                    openHeap.push(neighbourState);
                } else {
                    // Already queued (visited but not closed). g strictly decreased and h is fixed after the
                    // first visit, so f strictly decreased: the heap order can only be broken towards the root,
                    // and a single sift-up restores it. Sifting down first could never move the entry here, and
                    // would have left `heapIndex` pointing at the wrong element if it ever did.
                    // NOTE: relies on TinyQueue internals (`data`, `_up`).
                    // TODO: Pull the heap implementation in house so the index can be tracked instead of found
                    // with an O(n) indexOf.
                    const heapIndex = openHeap.data.indexOf(neighbourState);
                    if (heapIndex >= 0) {
                        openHeap["_up"](heapIndex);
                    }
                }
            }
        }
    }

    return undefined;
}

/**
 * Builds a fresh search record for `node`: all costs zeroed, not yet visited or closed, and with no
 * parent or node state attached.
 */
function createSearchState<T>(node: T): NodeSearchState<T> {
    const record: NodeSearchState<T> = { node, f: 0, h: 0, g: 0, cost: 0, isClosed: false, isVisited: false };
    return record;
}

/**
 * Reconstructs the path ending at `node` by following `parent` links back to the start node.
 *
 * @param node - The search record of the final node on the path.
 * @returns The nodes from start to `node` inclusive; the total cost (`node.cost`); the cumulative cost at
 *   each node after the start (`pathCost[i]` belongs to `path[i + 1]`); and, when at least one node after
 *   the start carries a non-nullish state, the per-node states aligned with `pathCost`.
 */
function pathTo<T>(node: NodeSearchState<T>): PathResult<T> {
    const path: T[] = [];
    const costs: number[] = [];
    const states: unknown[] = [];
    // `null`/`undefined` mean "no state"; other falsy values such as 0, false or "" are real states and
    // must be reported in position.
    let hasState = false;

    // Collect from the end back to the start. Pushing and reversing once keeps this linear, whereas
    // repeated `unshift` is quadratic in the path length.
    let curr = node;
    while (curr.parent) {
        path.push(curr.node);
        costs.push(curr.cost);
        states.push(curr.state);
        if (curr.state != null) {
            hasState = true;
        }
        curr = curr.parent;
    }
    // The start node has no parent: it appears in `path` but has no incoming cost or state.
    path.push(curr.node);

    path.reverse();
    costs.reverse();
    states.reverse();

    return { path, cost: node.cost, pathCost: costs, pathState: hasState ? states : undefined };
}
