import * as d3 from "d3";

// CSS class names given to the SVG <g> groups created by this module.
// The select* helpers in this file look the groups up again as `g.<class>`.

// Top-level content group appended directly under the <svg> in initPlot.
export const CLS_MAIN = "main";
// Group that holds the x/y axis sub-groups (see addGroupForAxes / plotAxes).
export const CLS_AXES = "axes";
// Group that holds the equation lines (see addGroupForLines / plotLines).
export const CLS_LINES = "lines";
// Group that holds point markers and their labels (see addGroupForPoints / plotPoints).
export const CLS_POINTS = "points";

/**
 * Appends an empty `<g>` carrying the `CLS_LINES` class to `contentGroup`.
 *
 * @param contentGroup - Selection that receives the new child group.
 */
export function addGroupForLines(contentGroup: d3.Selection<SVGGElement, unknown, null, any>) {
    const linesGroup = contentGroup.append("g");
    linesGroup.attr("class", CLS_LINES);
}

/**
 * Appends an empty `<g>` carrying the `CLS_POINTS` class to `contentGroup`.
 *
 * @param contentGroup - Selection that receives the new child group.
 */
export function addGroupForPoints(contentGroup: d3.Selection<SVGGElement, unknown, null, any>) {
    const pointsGroup = contentGroup.append("g");
    pointsGroup.attr("class", CLS_POINTS);
}

/**
 * Appends an empty `<g>` carrying the `CLS_AXES` class to `contentGroup`
 * and logs the parent selection to the console.
 *
 * @param contentGroup - Selection that receives the new child group.
 */
export function addGroupForAxes(contentGroup: d3.Selection<SVGGElement, unknown, null, any>) {
    const axesGroup = contentGroup.append("g");
    axesGroup.attr("class", CLS_AXES);
    console.log("Add group for axes", contentGroup);
}

/**
 * Selects the lines group via a document-wide `d3.select` on `g.<CLS_LINES>`.
 *
 * @returns The selection for that selector (empty when nothing matches).
 */
export function selectGroupLines(): d3.Selection<SVGGElement, unknown, HTMLElement, any> {
    const selector = `g.${CLS_LINES}`;
    return d3.select<SVGGElement, unknown>(selector);
}

/**
 * Selects the main content group via a document-wide `d3.select` on `g.<CLS_MAIN>`.
 *
 * @returns The selection for that selector (empty when nothing matches).
 */
export function selectGroupMain(): d3.Selection<SVGGElement, unknown, HTMLElement, any> {
    const selector = `g.${CLS_MAIN}`;
    return d3.select<SVGGElement, unknown>(selector);
}

/**
 * Selects the axes group via a document-wide `d3.select` on `g.<CLS_AXES>`.
 *
 * @returns The selection for that selector (empty when nothing matches).
 */
export function selectGroupAxes(): d3.Selection<SVGGElement, unknown, HTMLElement, any> {
    const selector = `g.${CLS_AXES}`;
    return d3.select<SVGGElement, unknown>(selector);
}

/**
 * Binds `points` to the `<circle>` elements inside the points group.
 *
 * @param points - Data array joined to the circles, one datum per circle.
 * @returns The update selection produced by `.data(points)`.
 */
export function selectPoints(points: number[][]) {
    const pointsGroup = d3.select(`g.${CLS_POINTS}`);
    const circles = pointsGroup.selectAll("circle");
    return circles.data(points);
}

/**
 * Geometry and scales shared by the plotting functions in this module.
 */
export interface PlotProps {
    /** Plot width; defaultPlotProps uses it to build the x scale's range. */
    width: number;
    /** Plot height; defaultPlotProps uses it to build the y scale's range. */
    height: number;
    /** Inset applied at both ends of each range by defaultPlotProps. */
    margin: number;
    /** Base x scale; the zoom handler in initPlot rescales it on every zoom event. */
    xScale: d3.ScaleLinear<number, number>;
    /** Base y scale; the zoom handler in initPlot rescales it on every zoom event. */
    yScale: d3.ScaleLinear<number, number>;
    /** X scale the drawing functions read; overwritten by the zoom handler in initPlot. */
    currentXScale: d3.ScaleLinear<number, number>;
    /** Y scale the drawing functions read; overwritten by the zoom handler in initPlot. */
    currentYScale: d3.ScaleLinear<number, number>;
}

/**
 * Builds the default plot configuration: a 500x500 area with a margin of 40
 * and linear scales covering the domain [-10, 10] on both axes.
 *
 * The y range runs from `height - margin` down to `margin`, so larger data
 * values map to smaller output coordinates.
 *
 * @returns A new PlotProps whose `currentXScale` / `currentYScale` initially
 *          reference the very same scale objects as `xScale` / `yScale`.
 */
export function defaultPlotProps(): PlotProps {
    const width = 500;
    const height = 500;
    const margin = 40;

    // Both axes share the same domain; only the output range differs.
    const linearOver = (range: [number, number]) => d3.scaleLinear().domain([-10, 10]).range(range);

    const xScale = linearOver([margin, width - margin]);
    const yScale = linearOver([height - margin, margin]);

    return {
        width,
        height,
        margin,
        xScale,
        yScale,
        currentXScale: xScale,
        currentYScale: yScale,
    };
}

/**
 * Resets the zoom transform stored on `svgElement` to the identity transform
 * without needing a reference to the zoom behavior that was attached to it.
 *
 * A throw-away zoom behavior is created only to borrow its `transform`
 * method, which is applied through a 1 ms transition on the SVG.
 *
 * NOTE(review): the behavior created here has no listeners registered, so
 * handlers attached to a different zoom behavior (e.g. the one built in
 * initPlot) are presumably not notified by this reset — confirm against the
 * d3-zoom documentation before relying on a redraw.
 *
 * @param svgElement - The SVG to reset; the call is a no-op when it is falsy.
 */
export function resetZoomWithoutAccess(svgElement: SVGSVGElement | null) {
    if (!svgElement) {
        return;
    }

    // Borrow `transform` from a fresh behavior; the cast bridges the typing
    // mismatch between the transition and the behavior's element type.
    const applyTransform = d3.zoom().transform as any;

    d3.select(svgElement).transition().duration(1).call(applyTransform, d3.zoomIdentity);
}

/**
 * Builds the plot from scratch inside `svgElement` and attaches zoom handling.
 *
 * Steps: remove every existing child of the SVG, create the main group with
 * its axes / lines / points sub-groups, register a zoom behavior, then draw
 * the initial graph with `x0` as the only point.
 *
 * On each zoom event the handler rescales the base scales of `plotProps`,
 * stores the results in `plotProps.currentXScale` / `currentYScale`
 * (mutating the caller's object), and redraws through zoomGraph.
 *
 * @param svgElement - Target SVG; the call is a no-op when it is falsy.
 * @param A - Coefficient rows forwarded to the line plotting.
 * @param b - Right-hand sides forwarded to the line plotting.
 * @param x0 - Initial point, drawn as the single entry of the points list.
 * @param plotProps - Dimensions and scales; mutated by the zoom handler.
 */
export function initPlot(
    svgElement: SVGSVGElement | null,
    A: number[][],
    b: number[],
    x0: number[],
    plotProps: PlotProps,
) {
    if (!svgElement) {
        return;
    }
    console.log("Init plot");

    // Start from an empty SVG so repeated initialisation does not stack content.
    const root = d3.select(svgElement);
    root.selectAll("*").remove();

    // Group hierarchy: main -> axes, lines, points (appended in that order).
    const contentGroup = root.append("g").attr("class", CLS_MAIN);
    addGroupForAxes(contentGroup);
    addGroupForLines(contentGroup);
    addGroupForPoints(contentGroup);

    const handleZoom = (event: any) => {
        // Always rescale from the base scales held in plotProps.
        const rescaledX = event.transform.rescaleX(plotProps.xScale);
        const rescaledY = event.transform.rescaleY(plotProps.yScale);

        plotProps.currentXScale = rescaledX;
        plotProps.currentYScale = rescaledY;

        console.log("zoom");
        console.log(plotProps.xScale.domain(), plotProps.xScale.range());
        console.log(plotProps.yScale.domain(), plotProps.yScale.range());
        console.log("---zoom");

        // Redraw with a props copy in which every scale slot is the rescaled one.
        const { width, height, margin } = plotProps;
        zoomGraph(
            {
                width,
                height,
                margin,
                xScale: rescaledX,
                yScale: rescaledY,
                currentXScale: rescaledX,
                currentYScale: rescaledY,
            },
            A,
            b,
        );
    };

    const zoomBehavior = d3.zoom().on("zoom", handleZoom);
    root.call(zoomBehavior as any);

    console.log("init");
    console.log(plotProps.xScale.domain(), plotProps.xScale.range());
    console.log(plotProps.yScale.domain(), plotProps.yScale.range());
    console.log("---init");

    redrawGraph(svgElement, plotProps, A, b, [x0]);
}

/**
 * Draws the complete graph: the axes, the equation lines, and the given
 * points, in that order.
 *
 * @param svgElement - Not read by this function's body.
 * @param plotProps - Forwarded to plotAxes, plotLines and plotPoints.
 * @param A - Coefficient rows forwarded to plotLines.
 * @param b - Right-hand sides forwarded to plotLines.
 * @param points - Points forwarded to plotPoints.
 */
export function redrawGraph(
    svgElement: SVGSVGElement,
    plotProps: PlotProps,
    A: number[][],
    b: number[],
    points: number[][],
) {
    plotAxes(plotProps);
    plotLines(A, b, plotProps);
    plotPoints(points, plotProps);
}

/**
 * Zoom-time redraw: replots the axes and the equation lines, and moves the
 * already-drawn points via repositionPointsOnZoom instead of re-plotting them
 * (no point list is passed in here).
 *
 * @param plotProps - Forwarded to plotAxes, plotLines and repositionPointsOnZoom.
 * @param A - Coefficient rows forwarded to plotLines.
 * @param b - Right-hand sides forwarded to plotLines.
 */
export function zoomGraph(plotProps: PlotProps, A: number[][], b: number[]) {
    plotAxes(plotProps);
    plotLines(A, b, plotProps);
    repositionPointsOnZoom(plotProps);
}

/**
 * Replaces the contents of the points group with one circle and one text
 * label per entry of `points`, positioned with the current scales.
 *
 * Each circle gets the id `id_<index>`, each label shows `(<index>)`, and the
 * last point is filled green while all others are black.
 *
 * The index comes from the position D3 passes to each accessor rather than
 * from `points.indexOf(p)`: `indexOf` compares by reference, so a point array
 * that occurs more than once would get the index of its first occurrence
 * (duplicate ids, wrong labels, last point not highlighted), and it makes the
 * whole draw quadratic in the number of points.
 *
 * @param points - Points to draw; element 0 is used as x and element 1 as y.
 * @param plotProps - Supplies `currentXScale` / `currentYScale` for positioning.
 */
export function plotPoints(points: number[][], plotProps: PlotProps) {
    const { currentXScale: xScale, currentYScale: yScale } = plotProps;
    const group = d3.select(`g.${CLS_POINTS}`);
    const lastIndex = points.length - 1;

    // Drop everything drawn previously so the enter selections below cover all points.
    group.selectAll("circle").remove();
    group.selectAll("text").remove();

    group
        .selectAll("circle")
        .data(points)
        .enter()
        .append("circle")
        .attr("id", (_p, i) => `id_${i}`)
        .attr("cx", (p) => xScale(p[0]))
        .attr("cy", (p) => yScale(p[1]))
        .attr("r", 5)
        .attr("fill", (_p, i) => (i === lastIndex ? "green" : "black"));

    // Labels sit slightly right of and above their circle.
    group
        .selectAll("text")
        .data(points)
        .enter()
        .append("text")
        .attr("x", (p) => xScale(p[0]) + 5)
        .attr("y", (p) => yScale(p[1]) - 5)
        .text((_p, i) => `(${i})`)
        .attr("font-size", "12px")
        .attr("fill", "black");
}

/**
 * Moves the circles and text labels already present in the points group to
 * the positions given by the current scales, using the datum bound to each
 * element (element 0 as x, element 1 as y). No elements are added or removed.
 *
 * @param plotProps - Supplies `currentXScale` / `currentYScale`.
 */
export function repositionPointsOnZoom(plotProps: PlotProps) {
    const { currentXScale: toX, currentYScale: toY } = plotProps;
    const pointsGroup = d3.select(`g.${CLS_POINTS}`);

    pointsGroup
        .selectAll<SVGCircleElement, number[]>("circle")
        .attr("cx", (pt) => toX(pt[0]))
        .attr("cy", (pt) => toY(pt[1]));

    // Labels keep the same +5 / -5 offset from their point as when first drawn.
    pointsGroup
        .selectAll<SVGTextElement, number[]>("text")
        .attr("x", (pt) => toX(pt[0]) + 5)
        .attr("y", (pt) => toY(pt[1]) - 5);
}

/**
 * Draws (or updates) the x and y axes inside the axes group.
 *
 * The `g.x-axis` / `g.y-axis` sub-groups are created on first use and reused
 * afterwards. The x axis is translated vertically to where y = 0 maps and the
 * y axis horizontally to where x = 0 maps, so the axes cross at the origin of
 * the current scales.
 *
 * Both sub-groups are looked up with a typed `select<SVGGElement>`, which
 * keeps the two axes consistent and avoids an `as any as ...` assertion.
 *
 * @param plotProps - Supplies `currentXScale` / `currentYScale`.
 */
export function plotAxes(plotProps: PlotProps) {
    const axesGroup = selectGroupAxes();
    const { currentXScale: xScale, currentYScale: yScale } = plotProps;

    let gX = axesGroup.select<SVGGElement>("g.x-axis");
    if (gX.empty()) {
        gX = axesGroup.append("g").attr("class", "x-axis");
    }
    gX.attr("transform", `translate(0, ${yScale(0)})`);
    gX.call(d3.axisBottom(xScale));

    let gY = axesGroup.select<SVGGElement>("g.y-axis");
    if (gY.empty()) {
        gY = axesGroup.append("g").attr("class", "y-axis");
    }
    gY.attr("transform", `translate(${xScale(0)}, 0)`);
    gY.call(d3.axisLeft(yScale));

    console.log("Plot axes");
}

/**
 * Plots the two equation lines `A[i][0] * x + A[i][1] * y = b[i]` (i = 0, 1)
 * inside the lines group, the first in red and the second in blue.
 *
 * Previously drawn content of the group is removed first. Each line is drawn
 * as a segment spanning the current x domain. A vertical line (y coefficient
 * equal to 0) cannot be written as y = f(x); it is drawn at x = b / a1 across
 * the current y domain instead of dividing by zero. An equation with both
 * coefficients equal to 0 does not describe a line and is skipped, as is any
 * row that is missing or incomplete.
 *
 * @param A - Coefficient rows; only the first two rows and columns are used.
 * @param b - Right-hand sides; only the first two entries are used.
 * @param plotProps - Supplies `currentXScale` / `currentYScale`.
 */
export function plotLines(A: number[][], b: number[], plotProps: PlotProps) {
    const g = selectGroupLines();
    g.selectAll("*").remove();

    const { currentXScale: xScale, currentYScale: yScale } = plotProps;

    const [xmin, xmax] = xScale.domain();
    const [ymin, ymax] = yScale.domain();

    // Draws one line for the equation a1 * x + a2 * y = rhs.
    const plotLine = (a1: number, a2: number, rhs: number, color: string) => {
        let start: { x: number; y: number };
        let end: { x: number; y: number };

        if (a2 !== 0) {
            // Regular case: solve for y at both ends of the visible x domain.
            start = { x: xmin, y: (rhs - a1 * xmin) / a2 };
            end = { x: xmax, y: (rhs - a1 * xmax) / a2 };
        } else if (a1 !== 0) {
            // Vertical line: x is fixed, span the visible y domain.
            const x = rhs / a1;
            start = { x, y: ymin };
            end = { x, y: ymax };
        } else {
            // 0 = rhs has no line to draw.
            return;
        }

        g.append("line")
            .attr("x1", xScale(start.x))
            .attr("y1", yScale(start.y))
            .attr("x2", xScale(end.x))
            .attr("y2", yScale(end.y))
            .attr("stroke", color)
            .attr("stroke-width", 2);
    };

    // One colour per equation, in row order.
    const colors = ["red", "blue"];
    colors.forEach((color, i) => {
        const row = A[i];
        const rhs = b[i];
        if (row === undefined || row.length < 2 || rhs === undefined) {
            return;
        }
        plotLine(row[0], row[1], rhs, color);
    });
}
