import * as d3 from "d3";
import { addGroupForAxes, CLS_MAIN, plotAxes, type PlotProps } from "./LinearSystemPlot";

/** CSS class of the `<g>` group (under the main group) that holds the plotted function curve. */
export const CLS_FUNC = "func";
/** CSS class of the `<g>` group (under the main group) that holds the interval markers: boundary lines, dots and the secant. */
export const CLS_INTERVALS = "intervals";

/**
 * Appends an empty `<g>` child to `contentGroup` and gives it `key` as its
 * `class` attribute, so it can later be found with a `g.<key>` selector.
 *
 * @param contentGroup Parent group that receives the new child group.
 * @param key Class name assigned to the new group.
 */
export function addGroupFor(contentGroup: d3.Selection<SVGGElement, unknown, null, any>, key: string) {
    const childGroup = contentGroup.append("g");
    childGroup.attr("class", key);
}

/**
 * Returns a selection of the `<g>` group that holds the function curve
 * (class {@link CLS_FUNC}).
 *
 * NOTE(review): `d3.select` with a string selector searches the whole document
 * and takes the first match, so with several plots on one page this always
 * returns the first plot's group.
 */
export function selectGroupFunc(): d3.Selection<SVGGElement, unknown, HTMLElement, any> {
    const selector = `g.${CLS_FUNC}`;
    return d3.select<SVGGElement, unknown>(selector);
}

/**
 * Returns a selection of the `<g>` group that holds the interval markers
 * (class {@link CLS_INTERVALS}).
 *
 * NOTE(review): `d3.select` with a string selector searches the whole document
 * and takes the first match, so with several plots on one page this always
 * returns the first plot's group.
 */
export function selectGroupIntervals(): d3.Selection<SVGGElement, unknown, HTMLElement, any> {
    const selector = `g.${CLS_INTERVALS}`;
    return d3.select<SVGGElement, unknown>(selector);
}

/**
 * Plot settings extended with the root-finding method being visualised.
 */
export interface PlotPropsExt extends PlotProps {
    /**
     * Which method's decorations to draw. In "regula-falsi" mode
     * plotIntervalMarkers additionally draws a secant between the curve points
     * at the ends of the latest interval.
     */
    displayMode: "bisection" | "regula-falsi";
}

/**
 * Initialises (or re-initialises) the plot inside `svgElement`.
 *
 * Removes all existing children of the SVG and creates the main group with
 * sub-groups for the axes, the function curve and the interval markers. It
 * then attaches a zoom/pan behaviour and draws the initial graph with
 * `interval` as the only marked interval.
 *
 * On every zoom event the zoom transform is applied to the base scales
 * `plotProps.xScale` / `plotProps.yScale`. The rescaled scales are written back
 * into `plotProps.currentXScale` / `plotProps.currentYScale`, so the caller's
 * object is mutated. The graph is then re-rendered with them.
 *
 * @param svgElement Target SVG element; the call is a no-op when it is null.
 * @param func Function to plot.
 * @param interval Initial interval to mark.
 * @param plotProps Plot configuration; mutated by zoom events (see above).
 */
export function initPlot(
    svgElement: SVGSVGElement | null,
    func: (x: number) => number,
    interval: [number, number],
    plotProps: PlotPropsExt,
) {
    if (!svgElement) return;

    // Clear the previous graph
    const svg = d3.select(svgElement);
    svg.selectAll("*").remove();

    // Top-level content group with one sub-group per layer
    const mainGroup = svg.append("g").attr("class", CLS_MAIN);
    addGroupForAxes(mainGroup);
    addGroupFor(mainGroup, CLS_FUNC);
    addGroupFor(mainGroup, CLS_INTERVALS);

    // Typing the behaviour with the element it is attached to lets
    // `svg.call(zoom)` type-check without an `as any` cast.
    const zoom = d3.zoom<SVGSVGElement, unknown>().on("zoom", (event) => {
        // Always rescale from the base scales: the event transform is absolute, not incremental.
        const newXScale = event.transform.rescaleX(plotProps.xScale);
        const newYScale = event.transform.rescaleY(plotProps.yScale);

        plotProps.currentXScale = newXScale;
        plotProps.currentYScale = newYScale;

        // Update the graph using the new scales
        zoomGraph(
            {
                width: plotProps.width,
                height: plotProps.height,
                margin: plotProps.margin,
                xScale: newXScale,
                yScale: newYScale,
                currentXScale: newXScale,
                currentYScale: newYScale,
                displayMode: plotProps.displayMode,
            },
            func,
        );
    });

    svg.call(zoom);

    redrawGraph(svgElement, plotProps, func, [interval]);
}

/**
 * Fully redraws the graph: axes (via plotAxes), the function curve, and the
 * markers for every interval in `intervals`. The function curve and the markers
 * are positioned with `plotProps.currentXScale` / `plotProps.currentYScale`.
 *
 * @param svgElement Currently unused by this function (it is not referenced in
 *   the body); presumably kept for API symmetry — confirm before removing.
 * @param plotProps Plot configuration.
 * @param func Function being plotted.
 * @param intervals Intervals to mark, in iteration order; the last one is highlighted.
 */
export function redrawGraph(
    svgElement: SVGSVGElement,
    plotProps: PlotPropsExt,
    func: (x: number) => number,
    intervals: [number, number][],
) {
    plotAxes(plotProps);
    plotFunction(func, plotProps);
    //plotPoints(points, plotProps);
    plotIntervalMarkers(intervals, func, plotProps);
}

/**
 * Re-renders the graph after a zoom/pan event.
 *
 * Axes (via plotAxes) and the function curve are redrawn. The interval markers
 * that already exist are only moved to their new positions; none are added or
 * removed. The curve and the markers use `plotProps.currentXScale` /
 * `plotProps.currentYScale`.
 *
 * @param plotProps Plot configuration carrying the rescaled scales.
 * @param func Function being plotted.
 */
export function zoomGraph(plotProps: PlotPropsExt, func: (x: number) => number) {
    plotAxes(plotProps);
    plotFunction(func, plotProps);
    repositionIntervalsOnZoom(func, plotProps);
}

/**
 * Draws the graph of `func` over the currently visible x-range.
 *
 * Any previous content of the function group is removed first.
 *
 * The visible x-domain of `plotProps.currentXScale` is sampled at
 * `samples + 1` evenly spaced points, with both ends included. The samples are
 * joined by a teal polyline. Samples where `func` returns a non-finite value
 * (NaN, ±Infinity) are skipped, which leaves a gap in the curve instead of
 * corrupting the SVG path.
 *
 * @param func Function to plot.
 * @param plotProps Plot configuration; `currentXScale` / `currentYScale` map data to pixels.
 * @param samples Number of segments the visible x-range is divided into (default 100, minimum 1).
 */
export function plotFunction(func: (x: number) => number, plotProps: PlotPropsExt, samples = 100) {
    const g = selectGroupFunc();
    g.selectAll("*").remove();

    const [xScale, yScale] = [plotProps.currentXScale, plotProps.currentYScale];
    const [xmin, xmax] = xScale.domain();

    // Derive each x from its index rather than stepping with d3.range(start, stop, step).
    // That form excludes `stop`, so the curve ended one step short of the right edge.
    const segments = Math.max(1, Math.floor(samples));
    const points = d3.range(segments + 1).map((i) => {
        const x = xmin + ((xmax - xmin) * i) / segments;
        return [x, func(x)];
    });

    const line = d3
        .line<number[]>()
        // Break the path at non-finite values; otherwise "NaN" ends up in the `d` attribute.
        .defined((d) => Number.isFinite(d[1]))
        .x((d) => xScale(d[0]))
        .y((d) => yScale(d[1]));

    g.append("path")
        .datum(points)
        .attr("fill", "none")
        .attr("stroke", "teal")
        .attr("stroke-width", 2)
        .attr("d", line);
}

/**
 * Draws markers for the bracketing intervals of a root-finding run.
 *
 * For every interval endpoint `x` this draws:
 * - a vertical line from the x-axis (y = 0) to the top edge of the visible
 *   y-domain if `func(x) >= 0`, or to the bottom edge otherwise;
 * - a dot at `(x, func(x))`, i.e. on the curve.
 *
 * Markers of the last interval are red; earlier ones are grey lines with black
 * dots. In "regula-falsi" mode a green secant also joins the curve points at
 * the two ends of the last interval. Everything from a previous call
 * (including a secant) is removed first, so this is a full redraw of the
 * markers. An empty `intervals` array simply clears them.
 *
 * @param intervals Intervals to mark, in iteration order; the last one is highlighted.
 * @param func Function whose root is being bracketed.
 * @param plotProps Plot configuration; `currentXScale` / `currentYScale` position the markers.
 */
export function plotIntervalMarkers(
    intervals: [number, number][],
    func: (x: number) => number,
    plotProps: PlotPropsExt,
) {
    const [xScale, yScale] = [plotProps.currentXScale, plotProps.currentYScale];
    const [ymin, ymax] = yScale.domain();
    const g = selectGroupIntervals();

    // Each interval contributes one marker per endpoint.
    const boundaries = intervals.flatMap(([left, right], i) => {
        const isLastInterval = i === intervals.length - 1;
        return [left, right].map((boundary) => ({
            boundary,
            isLastInterval,
            // The line points toward the sign of f at the boundary: up to the top
            // of the visible y-range when f >= 0, down to the bottom otherwise.
            endPoint: func(boundary) >= 0 ? ymax : ymin,
        }));
    });

    // Full redraw. The secant is removed unconditionally so it does not linger
    // after the display mode is switched away from "regula-falsi".
    g.selectAll("line.interval").remove();
    g.selectAll("circle.interval-dot").remove();
    g.selectAll("line.secant").remove();

    const yAxis = yScale(0);

    g.selectAll("line.interval")
        .data(boundaries)
        .enter()
        .append("line")
        .attr("class", "interval")
        .attr("x1", (d) => xScale(d.boundary))
        .attr("x2", (d) => xScale(d.boundary))
        .attr("y1", yAxis) // Start on the x-axis
        .attr("y2", (d) => yScale(d.endPoint)) // End at the top/bottom edge of the visible range
        .attr("stroke", (d) => (d.isLastInterval ? "red" : "grey"))
        .attr("stroke-width", 2);

    g.selectAll("circle.interval-dot")
        .data(boundaries)
        .enter()
        .append("circle")
        .attr("class", "interval-dot")
        .attr("cx", (d) => xScale(d.boundary))
        .attr("cy", (d) => yScale(func(d.boundary))) // The point on the function graph
        .attr("r", 3)
        .attr("fill", (d) => (d.isLastInterval ? "red" : "black"));

    // The interval is bound as the secant's datum so repositionIntervalsOnZoom
    // can recompute its coordinates. Skipped when there are no intervals.
    const lastInterval = intervals[intervals.length - 1];
    if (plotProps.displayMode === "regula-falsi" && lastInterval !== undefined) {
        const [a, b] = lastInterval;
        g.selectAll("line.secant")
            .data([lastInterval])
            .enter()
            .append("line")
            .attr("class", "secant")
            .attr("stroke", "green")
            .attr("stroke-width", 1)
            .attr("x1", xScale(a))
            .attr("x2", xScale(b))
            .attr("y1", yScale(func(a)))
            .attr("y2", yScale(func(b)));
    }
}

/**
 * Moves the already-drawn interval markers to match the current (zoomed) scales.
 *
 * Only attributes of existing elements are updated; nothing is added or
 * removed. Positions are recomputed from the data bound when the markers were
 * created:
 * - boundary lines run again from the x-axis to the top/bottom edge of the new
 *   visible y-range, according to the sign of `func` at the boundary;
 * - dots are moved back onto the curve;
 * - any secant line present is moved to join the curve at its bound
 *   interval's endpoints. Mode is not checked, so a secant never stays at stale
 *   pixel coordinates.
 *
 * @param func Function whose root is being bracketed.
 * @param plotProps Plot configuration carrying the rescaled `currentXScale` / `currentYScale`.
 */
export function repositionIntervalsOnZoom(func: (x: number) => number, plotProps: PlotPropsExt) {
    const [xScale, yScale] = [plotProps.currentXScale, plotProps.currentYScale];
    const [ymin, ymax] = yScale.domain();
    const g = selectGroupIntervals();
    const yAxis = yScale(0);

    // Boundary lines and dots carry `{ boundary, ... }` data; only `boundary` is needed here.
    g.selectAll<SVGLineElement, { boundary: number }>("line.interval")
        .attr("x1", (d) => xScale(d.boundary))
        .attr("x2", (d) => xScale(d.boundary))
        .attr("y1", yAxis)
        .attr("y2", (d) => yScale(func(d.boundary) >= 0 ? ymax : ymin));

    g.selectAll<SVGCircleElement, { boundary: number }>("circle.interval-dot")
        .attr("cx", (d) => xScale(d.boundary))
        .attr("cy", (d) => yScale(func(d.boundary)));

    // The secant (if any) is bound to its `[left, right]` interval.
    g.selectAll<SVGLineElement, [number, number]>("line.secant")
        .attr("x1", (d) => xScale(d[0]))
        .attr("x2", (d) => xScale(d[1]))
        .attr("y1", (d) => yScale(func(d[0])))
        .attr("y2", (d) => yScale(func(d[1])));
}
