import {
    Light,
    PositionedLight,
    Location,
    LightType,
    TokenAudio,
    Campaign,
    Token,
    LocationLevel,
} from "../../../store";
import { Mesh, BufferGeometry, Vector3, ShaderMaterial, PointLight } from "three";
import { intersect, IGrid, pointsEqual, ILocalGrid } from "../../../grid";
import { Point, Rect, Position, LocalPixelPosition, LocalRect } from "../../../position";
import { getObstructionPolygon, ObstructingAnnotation } from "../../../annotations";
import { Paths } from "js-angusj-clipper";
import { RootState } from "@react-three/fiber";
import { commonShaderFunctions, commonShaderUniforms } from "../common";
import { DeepPartial } from "../../../common";

/**
 * One end of a {@link Segment}, annotated with values computed relative to the current light source.
 */
export interface EndPoint extends Point {
    /**
     * Set by setSegmentBeginning: true for the endpoint reached first when sweeping in increasing
     * angle (with wrap-around at ±π), i.e. the endpoint that opens its segment during the sweep.
     */
    beginsSegment?: boolean;
    /** Back-reference to the owning segment; assigned by segment(). */
    segment?: Segment;
    /** Angle from the light source in radians, as produced by Math.atan2 (range (-π, π]). */
    angle: number;
}

/** An occluding edge between two endpoints. */
export interface Segment {
    p1: EndPoint;
    p2: EndPoint;
    /**
     * Squared distance from the light source to this segment's midpoint (set by
     * calculateEndPointAngles). NOTE(review): not read by any code visible in this module.
     */
    d: number;
}

/**
 * Cached render objects for a single light.
 * NOTE(review): field meanings below are inferred from names/types; confirm against the code that
 * creates these entries.
 */
export interface LightCache<T = {}> {
    /** The light being rendered, including any template-specific properties T. */
    source: PositionedLight & T;
    mesh: Mesh;
    template: LightShaderTemplate<T>;
    /** Presumably the most recent visibility polygon computed for `source` — TODO confirm. */
    visibilityPolygon?: Paths;
}

/** The argument passed to {@link LightShaderTemplate.onBeforeRender}. */
interface LightShaderRenderState<T> {
    /** react-three-fiber render state. */
    context: RootState;
    material: ShaderMaterial;
    mesh: Mesh;
    /** The light, with every property (template-specific and positioned-light) present. */
    light: Required<T & PositionedLight>;
    grid: IGrid;
    setPosition: (pos: LocalPixelPosition) => void;
    /** May be null — presumably when no three.js PointLight is attached to this light. */
    pointLight: PointLight | null;
}

/**
 * Describes a kind of light: identity/metadata, its fragment-shader snippet, and lifecycle hooks.
 * @typeParam T Extra, template-specific properties carried by lights of this type.
 */
export interface LightShaderTemplate<T = {}> {
    id: LightType;
    /** Human-readable name (presumably shown in UI). */
    label: string;
    tags: string[];
    /** GLSL source for this light; presumably the `code` argument to createLightShader — TODO confirm. */
    shader: string;
    /** Defaults: every template-specific property in T, plus optionally any Light field except `type`. */
    defaults?: Required<T> & Partial<Pick<Light, Exclude<keyof Light, "type">>>;
    /** Called with the light's material, presumably once after it is created. */
    onInit?: (material: ShaderMaterial) => void;
    /** Per-frame hook, presumably invoked before the light's mesh renders. */
    onBeforeRender?: (state: LightShaderRenderState<T>) => void;
    //updateUniforms?: (material: ShaderMaterial, light: Required<T & PositionedLight>) => void;

    sound?: TokenAudio;
}

/**
 * Fractional part of `n`, computed as `n` minus its integer part truncated toward zero.
 *
 * Unlike GLSL `fract` (which uses floor), negative inputs yield a result in (-1, 0],
 * e.g. `fract(-2.5) === -0.5`.
 */
export function fract(n: number) {
    const wholePart = Math.trunc(n);
    return n - wholePart;
}

/**
 * Creates a segment from (x1, y1) to (x2, y2).
 *
 * Angles and distance start at 0 (they are filled in per light source by
 * calculateEndPointAngles), and each endpoint links back to the new segment.
 */
function segment(x1: number, y1: number, x2: number, y2: number): Segment {
    const created: Segment = {
        p1: { x: x1, y: y1, angle: 0 } as EndPoint,
        p2: { x: x2, y: y2, angle: 0 } as EndPoint,
        d: 0,
    };

    // Endpoints are sorted independently during the sweep, so each needs a way back to its segment.
    created.p1.segment = created;
    created.p2.segment = created;

    return created;
}

/**
 * Updates `segment` in place for a light at `source`.
 *
 * Sets `d` to the squared distance from `source` to the segment midpoint, and sets each
 * endpoint's `angle` to its Math.atan2 bearing from `source`.
 */
function calculateEndPointAngles(source: Point, segment: Segment) {
    const { p1, p2 } = segment;
    const originX = source.x;
    const originY = source.y;

    const midOffsetX = 0.5 * (p1.x + p2.x) - originX;
    const midOffsetY = 0.5 * (p1.y + p2.y) - originY;
    segment.d = midOffsetX * midOffsetX + midOffsetY * midOffsetY;

    p1.angle = Math.atan2(p1.y - originY, p1.x - originX);
    p2.angle = Math.atan2(p2.y - originY, p2.x - originX);
}

/**
 * Marks which endpoint of `segment` opens it during an increasing-angle sweep.
 *
 * The angular span p1→p2 is normalised into (-π, π] so segments crossing the ±π seam are
 * handled. A positive span means p1 comes first, so p1 begins the segment; otherwise p2 does.
 * Requires the endpoint angles to be set first (calculateEndPointAngles).
 */
function setSegmentBeginning(segment: Segment) {
    const fullTurn = 2 * Math.PI;
    let span = segment.p2.angle - segment.p1.angle;

    if (span <= -Math.PI) span += fullTurn;
    if (span > Math.PI) span -= fullTurn;

    const firstIsStart = span > 0;
    segment.p1.beginsSegment = firstIsStart;
    segment.p2.beginsSegment = !firstIsStart;
}

/**
 * Prepares every segment for a sweep around `source`: computes endpoint angles and distance,
 * then decides which endpoint begins each segment. Mutates the segments in place.
 */
function processSegments(source: Point, segments: Segment[]) {
    for (const seg of segments) {
        calculateEndPointAngles(source, seg);
        setSegmentBeginning(seg);
    }
}

/**
 * Returns the four edges of `rect` as segments. loadSegments uses them as the outer boundary
 * of the occluder set.
 */
function segmentsFromRect(rect: Rect) {
    const minX = rect.x;
    const minY = rect.y;
    const maxX = rect.x + rect.width;
    const maxY = rect.y + rect.height;

    return [
        segment(minX, minY, maxX, minY),
        segment(minX, minY, minX, maxY),
        segment(maxX, minY, maxX, maxY),
        segment(minX, maxY, maxX, maxY),
    ];
}

/**
 * Builds the occluding segments used by the visibility sweep.
 *
 * The result starts with the four edges of `bounds`, then adds the edges of every obstruction
 * polygon. Each polygon's `points` are offset by its `pos`, and closed polygons with more than
 * two points also get their closing edge. Finally, segments that cross each other are split at
 * the crossing point, because the sweep cannot order two segments that cross.
 *
 * @param level Currently unused.
 * @param tokenOverrides Forwarded to getObstructionPolygon.
 * @returns A new array of segments.
 */
export function loadSegments(
    campaign: Campaign,
    location: Location,
    level: LocationLevel,
    grid: ILocalGrid,
    bounds: LocalRect,
    obstructions: ObstructingAnnotation[],
    tokenOverrides?: { [id: string]: DeepPartial<Token> }
) {
    const segments: Segment[] = segmentsFromRect(bounds);

    for (const annotation of obstructions) {
        const obstruction = getObstructionPolygon(annotation, campaign, location, grid, tokenOverrides);
        const { points, pos } = obstruction;

        for (let k = 1; k < points.length; k++) {
            segments.push(
                segment(points[k - 1].x + pos.x, points[k - 1].y + pos.y, points[k].x + pos.x, points[k].y + pos.y)
            );
        }

        if (points.length > 2 && obstruction.isClosed) {
            const last = points[points.length - 1];
            segments.push(segment(last.x + pos.x, last.y + pos.y, points[0].x + pos.x, points[0].y + pos.y));
        }
    }

    splitCrossingSegments(segments);

    return segments;
}

/**
 * Splits, in place, every pair of segments that cross at a point interior to both of them.
 * Each such pair is replaced by four segments that meet at the crossing point.
 */
function splitCrossingSegments(segments: Segment[]) {
    for (let i = 0; i < segments.length; i++) {
        let a = segments[i];
        for (let j = i + 1; j < segments.length; j++) {
            const b = segments[j];
            const hit = intersect(a.p1, a.p2, false, b.p1, b.p2, false);
            if (!hit) {
                continue;
            }

            // Touching at an endpoint of either segment (a shared vertex or a T-junction) is not a
            // crossing. Splitting there would only produce zero-length segments.
            if (
                pointsEqual(hit, a.p1) ||
                pointsEqual(hit, a.p2) ||
                pointsEqual(hit, b.p1) ||
                pointsEqual(hit, b.p2)
            ) {
                continue;
            }

            const aHead = segment(a.p1.x, a.p1.y, hit.x, hit.y);
            const aTail = segment(hit.x, hit.y, a.p2.x, a.p2.y);

            // Replace b first (it has the higher index) so that i stays valid for the second splice.
            segments.splice(j, 1, segment(b.p1.x, b.p1.y, hit.x, hit.y), segment(hit.x, hit.y, b.p2.x, b.p2.y));
            segments.splice(i, 1, aHead, aTail);

            // The layout is now aHead@i, aTail@i+1, ..., bHead@j+1, bTail@j+2. Keep scanning the
            // not-yet-checked segments (from j+3 after the loop increment) against aHead, since
            // it may cross them too. aTail gets its own outer pass at i+1. Segments between
            // i+1 and j never crossed the original `a`, so they cannot cross aHead either.
            a = aHead;
            j += 2;
        }
    }
}

/**
 * Recomputes per-source data on `segments` (see processSegments) and returns all their
 * endpoints, two per segment in [p1, p2] order.
 */
function loadEndpoints(source: Point, segments: Segment[]) {
    processSegments(source, segments);

    const endpoints: EndPoint[] = [];
    for (const { p1, p2 } of segments) {
        endpoints.push(p1, p2);
    }
    return endpoints;
}

/**
 * Sort comparator for the sweep. Orders endpoints by ascending angle. At equal angles,
 * endpoints that begin a segment come before endpoints that end one.
 */
function endpointCompare(pointA: EndPoint, pointB: EndPoint) {
    if (pointA.angle > pointB.angle) return 1;
    if (pointA.angle < pointB.angle) return -1;

    const aBegins = !!pointA.beginsSegment;
    const bBegins = !!pointB.beginsSegment;
    if (aBegins === bBegins) {
        return 0;
    }
    return aBegins ? -1 : 1;
}

/**
 * Side test. Returns true when the 2D cross product (p2 - p1) × (point - p1) is negative.
 * Points exactly on the segment's line (cross product 0) return false.
 */
function leftOf(segment: Segment, point: Point) {
    const { p1, p2 } = segment;
    const edgeX = p2.x - p1.x;
    const edgeY = p2.y - p1.y;
    const toPointX = point.x - p1.x;
    const toPointY = point.y - p1.y;
    return edgeX * toPointY - edgeY * toPointX < 0;
}

/** Linear interpolation between two points: f = 0 gives pointA, f = 1 gives pointB. */
function interpolate(pointA: Point, pointB: Point, f: number) {
    const weightA = 1 - f;
    const x = pointA.x * weightA + pointB.x * f;
    const y = pointA.y * weightA + pointB.y * f;
    return { x, y };
}

/**
 * Depth test between two segments as seen from `relativePoint` (the light source).
 *
 * Despite the name, a `true` result means `segmentB` occludes `segmentA`. It is true when A
 * lies wholly on the far side of B's line from the viewer, or when B lies wholly on the viewer's
 * side of A's line. In every other case the result is false.
 * calculateVisibileSegments relies on this polarity when inserting into its open-segment list.
 */
function segmentInFrontOf(segmentA: Segment, segmentB: Segment, relativePoint: Point) {
    // Sample 1% in from each end rather than at the endpoints themselves, which may lie exactly
    // on the other segment's line (e.g. at a shared vertex).
    const bNearP1 = interpolate(segmentB.p1, segmentB.p2, 0.01);
    const bNearP2 = interpolate(segmentB.p2, segmentB.p1, 0.01);
    const aNearP1 = interpolate(segmentA.p1, segmentA.p2, 0.01);
    const aNearP2 = interpolate(segmentA.p2, segmentA.p1, 0.01);

    const bSideOfA = leftOf(segmentA, bNearP1);
    const bStraddlesA = bSideOfA !== leftOf(segmentA, bNearP2);
    const viewerSideOfA = leftOf(segmentA, relativePoint);

    const aSideOfB = leftOf(segmentB, aNearP1);
    const aStraddlesB = aSideOfB !== leftOf(segmentB, aNearP2);
    const viewerSideOfB = leftOf(segmentB, relativePoint);

    const aBeyondB = !aStraddlesB && aSideOfB !== viewerSideOfB;
    const bBeforeA = !bStraddlesA && bSideOfA === viewerSideOfA;
    return aBeyondB || bBeforeA;
}

/**
 * Computes the far edge of one visibility triangle. This is where the rays from `origin` at
 * `angle1` and `angle2` meet the line through `segment`.
 *
 * @param origin The light position. Its `type` is copied onto the helper points.
 * @param angle1 Start angle of the wedge, in radians.
 * @param angle2 End angle of the wedge, in radians.
 * @param segment The front open segment across the wedge. This may be undefined when nothing is
 *   open (the caller passes `openSegments[0]`). In that case a chord at a fixed distance from
 *   `origin` is used instead.
 * @returns `[pointAtAngle1, pointAtAngle2]`, or undefined when the wedge is empty (equal angles)
 *   or either ray fails to meet the edge line.
 */
function getTrianglePoints<T extends Position>(
    origin: T,
    angle1: number,
    angle2: number,
    segment: Segment | undefined
): [T, T] | undefined {
    if (angle1 === angle2) {
        return undefined;
    }

    // Distance of the substitute chord used when there is no segment to clip against.
    const fallbackDistance = 200;

    const edgeStart = { type: origin.type, x: 0, y: 0 } as T;
    const edgeEnd = { type: origin.type, x: 0, y: 0 } as T;

    if (segment) {
        edgeStart.x = segment.p1.x;
        edgeStart.y = segment.p1.y;
        edgeEnd.x = segment.p2.x;
        edgeEnd.y = segment.p2.y;
    } else {
        edgeStart.x = origin.x + Math.cos(angle1) * fallbackDistance;
        edgeStart.y = origin.y + Math.sin(angle1) * fallbackDistance;
        edgeEnd.x = origin.x + Math.cos(angle2) * fallbackDistance;
        edgeEnd.y = origin.y + Math.sin(angle2) * fallbackDistance;
    }

    const pBegin = castRayToEdge(origin, angle1, edgeStart, edgeEnd);
    if (!pBegin) {
        return undefined;
    }

    const pEnd = castRayToEdge(origin, angle2, edgeStart, edgeEnd);
    if (!pEnd) {
        return undefined;
    }

    return [pBegin, pEnd];
}

/**
 * Intersects the ray from `origin` at `angle` with the line through `edgeStart`/`edgeEnd`.
 * A fresh direction point is built for each ray, so two results never share a mutated object.
 * Logs and returns undefined when `intersect` yields no point (e.g. the lines are parallel).
 */
function castRayToEdge<T extends Position>(origin: T, angle: number, edgeStart: T, edgeEnd: T): T | undefined {
    const towards = { type: origin.type, x: origin.x + Math.cos(angle), y: origin.y + Math.sin(angle) } as T;
    const hit = intersect(edgeStart, edgeEnd, true, origin, towards, true);
    if (!hit) {
        console.log(
            `Failure to intersect: ${edgeStart.x},${edgeStart.y}->${edgeEnd.x},${edgeEnd.y} and ${origin.x},${origin.y}->${towards.x},${towards.y}`
        );
        return undefined;
    }

    return hit as T;
}

/**
 * Angular sweep around `origin` that collects the visible edge of each wedge.
 *
 * - `endpoints` is sorted in place with endpointCompare: by angle, with begins before ends on ties.
 * - `open` holds the segments the sweep currently crosses. A newly opened segment is inserted
 *   after every open segment for which `segmentInFrontOf(newSegment, openSegment)` is true, and
 *   `open[0]` is treated as the visible front segment.
 * - Whenever the front segment changes, the wedge since the previous change is closed off
 *   against the previous front segment via getTrianglePoints.
 * - The sweep runs twice, but only the second pass emits triangles. The first pass only primes
 *   `open` and the start angle, so segments spanning the ±π seam are already open when
 *   emission starts.
 *
 * @returns Point pairs on the visible edges. Each pair, together with `origin`, forms one
 *   triangle of the lit area.
 */
function calculateVisibileSegments<T extends Position>(origin: T, endpoints: EndPoint[]) {
    const open: Segment[] = [];
    const triangles: [Point, Point][] = [];
    let wedgeStart = 0;

    endpoints.sort(endpointCompare);

    for (const emit of [false, true]) {
        for (const endpoint of endpoints) {
            // Capture the front before this endpoint changes the open list.
            const frontBefore = open[0];
            const owner = endpoint.segment as Segment;

            if (endpoint.beginsSegment) {
                let slot = 0;
                while (open[slot] && segmentInFrontOf(owner, open[slot], origin)) {
                    slot++;
                }

                if (open[slot]) {
                    open.splice(slot, 0, owner);
                } else {
                    open.push(owner);
                }
            } else {
                const slot = open.indexOf(owner);
                if (slot > -1) {
                    open.splice(slot, 1);
                }
            }

            if (frontBefore === open[0]) {
                continue;
            }

            if (emit) {
                const wedge = getTrianglePoints(origin, wedgeStart, endpoint.angle, frontBefore);
                if (wedge) {
                    triangles.push(wedge);
                }
            }

            wedgeStart = endpoint.angle;
        }
    }

    return triangles;
}

/**
 * Runs the visibility sweep for a light at `source`.
 *
 * Side effect: the angle, distance and begin/end data stored on `segments` and their endpoints
 * are recomputed for `source`.
 *
 * @returns Point pairs on the visible edges. Each pair, together with `source`, forms one
 *   triangle of the lit area.
 */
export function getSegmentsForSource(source: LocalPixelPosition, segments: Segment[]) {
    return calculateVisibileSegments(source, loadEndpoints(source, segments));
}

/**
 * Replaces `mesh.geometry` with a triangle list built from precomputed visibility segments.
 *
 * Each entry `[begin, end]` becomes one triangle: (source, end, begin), with y negated
 * (presumably because local pixel coordinates are y-down). The previous geometry is disposed
 * so its GPU buffers are released.
 *
 * @param segments Visible edge point pairs, e.g. from getSegmentsForSource. Entries with a
 *   missing point are removed from this array in place, with a console warning.
 * @returns The same `segments` array, after filtering.
 */
export function updateGeometryForSegments(mesh: Mesh, source: LocalPixelPosition, segments: [Point, Point][]) {
    // TODO: Try reversing the order of the points so we don't need to use DoubleSide on the materials.
    const points: Vector3[] = [];

    // Compact in place (O(n), same array identity) instead of splicing inside the loop.
    let kept = 0;
    for (let read = 0; read < segments.length; read++) {
        const triangle = segments[read];
        if (triangle[0] == null || triangle[1] == null) {
            console.warn("Visibility triangle contains null point, ignoring.");
            continue;
        }

        segments[kept++] = triangle;
        points.push(
            new Vector3(source.x, -source.y, 0),
            new Vector3(triangle[1].x, -triangle[1].y, 0),
            new Vector3(triangle[0].x, -triangle[0].y, 0)
        );
    }
    segments.length = kept;

    const geometry = new BufferGeometry();
    geometry.setFromPoints(points);

    const previous = mesh.geometry;
    mesh.geometry = geometry;
    // three.js keeps a geometry's GPU buffers until dispose() is called. Replacing it without
    // disposing would leak those buffers on every update.
    previous?.dispose();

    return segments;
}

/**
 * Runs the visibility sweep for a light at `source` and replaces `mesh.geometry` with the
 * resulting lit area as a triangle list.
 *
 * Each visible edge `[begin, end]` becomes the triangle (source, end, begin), with y negated
 * (presumably because local pixel coordinates are y-down). The previous geometry is disposed
 * so its GPU buffers are released. Like getSegmentsForSource, this recomputes per-source data
 * stored on `segments`.
 *
 * @returns The visible edge point pairs that were turned into triangles. Pairs with a missing
 *   point are dropped, with a console warning.
 */
export function updateGeometryForSource(mesh: Mesh, source: LocalPixelPosition, segments: Segment[]) {
    const visibility = getSegmentsForSource(source, segments).filter(triangle => {
        if (triangle[0] == null || triangle[1] == null) {
            console.warn("Visibility triangle contains null point, ignoring.");
            return false;
        }
        return true;
    });

    // TODO: Try reversing the order of the points so we don't need to use DoubleSide on the materials.
    const points = visibility.flatMap(([begin, end]) => [
        new Vector3(source.x, -source.y, 0),
        new Vector3(end.x, -end.y, 0),
        new Vector3(begin.x, -begin.y, 0),
    ]);

    const geometry = new BufferGeometry();
    geometry.setFromPoints(points);

    const previous = mesh.geometry;
    mesh.geometry = geometry;
    // three.js keeps a geometry's GPU buffers until dispose() is called. Replacing it without
    // disposing would leak those buffers on every update.
    previous?.dispose();

    return visibility;
}

/**
 * Builds the GLSL fragment-shader source for a light.
 *
 * The output is laid out in this order:
 * - `commonShaderUniforms`
 * - the built-in light uniforms (`u_color`, `u_position`, `u_near`, `u_far`, `u_brightness`,
 *   `u_sdf`, `u_sdfSize`, `u_useSdf`)
 * - the optional extra `uniforms` declarations
 * - `commonShaderFunctions`
 * - a `getRayMarch` helper
 * - `main()`
 *
 * Inside `main()`, `code` is spliced in verbatim. It is expected to write `gl_FragColor`,
 * because the final statement reads `gl_FragColor` and multiplies its alpha by the shadow
 * factor `ray`.
 *
 * Shadowing: when `u_useSdf` is true, `ray` comes from marching from the fragment toward
 * `u_position` through the SDF. The SDF is sampled with `getDistanceValue`, presumably defined
 * in `commonShaderFunctions`. The march returns 0 when a sample distance is <= 0 or when 64
 * steps are exhausted. Otherwise it returns a soft-shadow factor clamped to [0, 1], or 1 when
 * `hardness > maxHardness`.
 *
 * Hardness is `clamp(25 - distanceAtLight, 1, 25) * 20`, which ranges over [20, 500]. With
 * `maxHardness = 400`, a light within 5 units of an obstruction therefore casts hard-edged
 * shadows.
 *
 * Note: do not add comments inside the template literal — they would become part of the
 * shader source.
 *
 * @param code GLSL statements for the body of `main()`.
 * @param uniforms Optional extra GLSL declarations inserted after the built-in uniforms.
 */
export function createLightShader(code: string, uniforms?: string) {
    return `
${commonShaderUniforms}

uniform vec3 u_color;
uniform vec2 u_position;
uniform float u_near;
uniform float u_far;
uniform float u_brightness;
uniform sampler2D u_sdf;
uniform float u_sdfSize;
uniform bool u_useSdf;

${uniforms ?? ""}

${commonShaderFunctions}

float getRayMarch(in vec2 rayOrigin, in vec2 rayDestination, float hardness, float maxHardness) {
    float rayLength = distance(rayOrigin, rayDestination);
    vec2 rayDirection = (rayDestination - rayOrigin) / rayLength;
    float rayProgress = 0.0001;
    float ph = 1e10;
    float shadow = 1e10;
    for (int i = 0; i < 64; i++) {
        if (rayProgress >= rayLength) {
            // We hit the light! This pixel is not in shadow.
            return hardness > maxHardness ? 1.0 : clamp(shadow, 0.0, 1.0);
        }
    
        vec2 marchPoint = rayOrigin + (rayProgress * rayDirection);
        float sceneDist = getDistanceValue(u_sdf, u_sdfSize, marchPoint);
        if (sceneDist <= 0.0) {
            // We hit a shape! This pixel is in shadow.
            return 0.0;
        }

        // Slightly more complicated than just using hardness * (sceneDist / rayProgress), but
        // it does smooth out some artifacts. From https://iquilezles.org/articles/rmshadows/ and
        // https://www.rykap.com/2020/09/23/distance-fields/
        float y = sceneDist * sceneDist / (2.0*ph);
        float d = sqrt(sceneDist*sceneDist-y*y);
        shadow = min(shadow, hardness * d/max(0.0,rayProgress-y));
        ph = sceneDist;

        rayProgress += sceneDist;
    }

    // Ray-marching took too many steps.
    return 0.0;
}

void main() {
    // gl_FragCoord is in canvas coords and has an origin of bottom left, we need top left.
    vec2 localFragCoord = toLocalPoint(fragCoordToScreenPoint(gl_FragCoord.xy));

    // Ramp up the hardness the closer the light is to an obstruction, helps avoid artefacts and just looks better.
    float ray = 1.0;
    if (u_useSdf) {
        float lightDist = getDistanceValue(u_sdf, u_sdfSize, u_position);
        float hardnessFactor = clamp(25.0 - lightDist, 1.0, 25.0);
        ray = getRayMarch(localFragCoord, u_position, hardnessFactor * 20.0, 400.0);
    }

    ${code}

    gl_FragColor = vec4(gl_FragColor.rgb, gl_FragColor.a * ray);    
}`;
}

/** Tags given to light templates; presumably used as the default `LightShaderTemplate.tags`. */
export const defaultLightTags: string[] = ["light"];
