
import * as d3 from 'd3';

import { Stats } from 'janstatistics';

import { sub } from './subscript';
import { CMD } from './cmd';
import { hLine, vLine, cross } from './cmd';

/** A single data point of the scatter plot, expressed in data (not pixel) units. */
type Coordinate = { x: number; y: number };

// σ

// TODO: replace the imported `Stats` type with a locally declared interface.


/** D3 selection wrapping the root `<svg>` element the plot is rendered into. */
type SvgRoot = d3.Selection<SVGSVGElement, unknown, null, undefined>;

/** D3 selection wrapping a `<g>` group inside the plot. */
type SvgGroup = d3.Selection<SVGGElement, unknown, null, undefined>;

/** Linear scale mapping data values to pixel positions. */
type LinearScale = d3.ScaleLinear<number, number>;

/** Free space (in pixels) reserved around the main plot area. */
type PlotMargin = { top: number; right: number; bottom: number; left: number };

/**
 * Renders a scatter plot with one-dimensional projection bands for x and y,
 * mean / mean ± SD guide lines, the point of averages and the SD line.
 *
 * `plot` is the entry point; the remaining public methods are the individual
 * rendering steps it composes.
 */
export class Plotter {

    /** x scale built by the most recent successful `plot` call; `null` until then. */
    _xScale: LinearScale | null = null;
    /** y scale built by the most recent successful `plot` call; `null` until then. */
    _yScale: LinearScale | null = null;
    /** Main plot group created by the most recent successful `plot` call; `null` until then. */
    _mainGraphG: SvgGroup | null = null;

    /**
     * Tells whether there are enough points to draw the plot.
     *
     * @param data the points to be plotted
     * @returns `true` when `data` holds at least two points
     */
    isDataSufficient(data: Coordinate[]) {
        return data.length > 1;
    }

    /**
     * Replaces the whole content of `svg` with a (Czech) message saying that
     * the visualization needs at least two data points.
     *
     * @param svg root SVG selection
     */
    showMessageWhenNoSufficientData(svg: SvgRoot) {
        // For n = 1 we cannot calculate the standard deviation.
        svg.selectAll("*").remove();
        svg.append('text')
            .attr('x', 10)
            .attr('y', 20)
            .text('Vizualizace funguje jen pro dva a více datových bodů.');
    }

    /**
     * @returns the margins around the main plot; the right and bottom margins
     *          are larger than the others (they are where `plot` places the
     *          projection bands)
     */
    setupMargins() {
        const margin = { top: 20, right: 120, bottom: 60, left: 40 };
        return margin;
    }

    /**
     * Removes every descendant of `svg`.
     *
     * @param svg root SVG selection
     */
    clearContent(svg: SvgRoot) {
        svg.selectAll("*").remove();
    }

    /**
     * Computes the data domains and the pixel geometry of the main plot.
     *
     * Each domain covers both the data extended by 10 % of its range on either
     * side and the interval mean ± 1.1 SD, so that the SD guide lines always
     * fall inside the plot.
     *
     * @param margin    margins subtracted from the 600 × 600 canvas
     * @param statsX    statistics of the x values (`min`, `max`, `range`, `average`, `std` are read)
     * @param statsY    statistics of the y values (same fields are read)
     * @param sameScale when `true`, the plot area is shrunk along one dimension
     *                  so that one data unit has the same pixel length on both axes
     * @returns plot `width`/`height`, the `xShift`/`yShift` that centre the
     *          (possibly shrunk) plot in the available area, and the domain bounds
     */
    prepareTopology(margin: PlotMargin, statsX: Stats, statsY: Stats, sameScale: boolean) {
        const canvasSize = 600;
        const maxWidth = canvasSize - margin.left - margin.right;
        const maxHeight = canvasSize - margin.top - margin.bottom;

        const xDomainL = Math.min(statsX.min - statsX.range * 0.1, statsX.average - statsX.std * 1.1);
        const xDomainR = Math.max(statsX.max + statsX.range * 0.1, statsX.average + statsX.std * 1.1);
        const yDomainL = Math.min(statsY.min - statsY.range * 0.1, statsY.average - statsY.std * 1.1);
        // The y padding must come from the y range (the x range was used here by mistake).
        const yDomainR = Math.max(statsY.max + statsY.range * 0.1, statsY.average + statsY.std * 1.1);

        const xDomainLength = xDomainR - xDomainL;
        const yDomainLength = yDomainR - yDomainL;

        let width = maxWidth;
        let height = maxHeight;

        // A common scale is only meaningful when both domains have a positive
        // length; a degenerate domain (all values equal) would otherwise yield
        // a zero or NaN plot size.
        if (sameScale && xDomainLength > 0 && yDomainLength > 0) {
            if (yDomainLength / xDomainLength > maxHeight / maxWidth) {
                // Height is the limiting dimension: shrink the width.
                width = maxHeight * xDomainLength / yDomainLength;
            } else {
                // Width is the limiting dimension: shrink the height.
                height = maxWidth * yDomainLength / xDomainLength;
            }
        }

        const xShift = (maxWidth - width) / 2;
        const yShift = (maxHeight - height) / 2;

        return { width, height, xShift, yShift, xDomainL, xDomainR, yDomainL, yDomainR };
    }

    /**
     * Builds the horizontal scale: `[domainL, domainR]` → `[0, width]`.
     */
    buildMainPlotScaleX(domainL: number, domainR: number, width: number) {
        return d3.scaleLinear()
            .domain([domainL, domainR])
            .range([0, width]);
    }

    /**
     * Builds the vertical scale: `[domainL, domainR]` → `[height, 0]`
     * (inverted, because SVG y grows downwards).
     */
    buildMainPlotScaleY(domainL: number, domainR: number, height: number) {
        return d3.scaleLinear()
            .domain([domainL, domainR])
            .range([height, 0]);
    }

    /**
     * Appends the group holding the main scatter plot, translated by the
     * left/top margin plus the centring shift.
     *
     * @returns the new `<g>` selection
     */
    buildMainPlotGroup(svg: SvgRoot, margin: PlotMargin, xShift: number, yShift: number) {
        return svg.append('g')
            .attr('transform', `translate(${margin.left + xShift},${margin.top + yShift})`);
    }

    /**
     * Draws one grey circle of radius 5 per data point.
     *
     * @returns the selection of the appended circles
     */
    drawScatterPlot(g: SvgGroup, data: Coordinate[], xScale: LinearScale, yScale: LinearScale) {
        return g.selectAll("circle")
            .data(data)
            .enter().append("circle")
            .attr("cx", d => xScale(d.x))
            .attr("cy", d => yScale(d.y))
            .attr("r", 5)
            .attr("fill", "grey");
    }

    /**
     * @returns the visual style of the ticks drawn in the projection bands
     */
    getProjectionStyle() {
        return {
            tickLength: 16,
            tickThickness: 4,
            tickColor: 'grey'
        };
    }

    /**
     * Appends the group of the x projection band, placed 40 px below the
     * main plot.
     *
     * @returns the new `<g>` selection
     */
    buildProjectionGroupX(svg: SvgRoot, margin: PlotMargin, height: number, xShift: number, yShift: number) {
        return svg.append('g')
            .attr('transform', `translate(${margin.left + xShift}, ${height + margin.top + 40 + yShift})`);
    }

    /**
     * Appends the group of the y projection band, placed 60 px to the right
     * of the main plot.
     *
     * @returns the new `<g>` selection
     */
    buildProjectionGroupY(svg: SvgRoot, margin: PlotMargin, width: number, xShift: number, yShift: number) {
        return svg.append('g')
            .attr('transform', `translate(${width + margin.left + 60 + xShift}, ${margin.top + yShift})`);
    }

    /**
     * Draws the bottom axis at the lower edge of the main plot.
     *
     * @param g      main plot group
     * @param scale  horizontal scale
     * @param height height of the main plot (the axis is translated down by it)
     * @returns the D3 axis generator that was rendered
     */
    buildAxisX(g: SvgGroup, scale: LinearScale, height: number) {
        const xAxis = d3.axisBottom(scale);
        g.append("g")
            .attr("transform", `translate(0,${height})`)
            .call(xAxis);
        return xAxis;
    }

    /**
     * Draws the left axis at the left edge of the main plot.
     *
     * @param g     main plot group
     * @param scale vertical scale
     * @returns the D3 axis generator that was rendered (mirrors `buildAxisX`)
     */
    buildAxisY(g: SvgGroup, scale: LinearScale) {
        const yAxis = d3.axisLeft(scale);
        g.append("g")
            .call(yAxis);
        return yAxis;
    }

    /**
     * Marks the point of averages with a red cross.
     *
     * @param g          main plot group
     * @param x          pixel x position of the mean
     * @param y          pixel y position of the mean
     * @param tickLength passed to `cross` as the size of the mark
     *                   (exact meaning is defined by `cross` in './cmd')
     */
    drawPointOfAverages(g: SvgGroup, x: number, y: number, tickLength: number) {
        cross(g, x, y, tickLength, 'red', 2);
    }

    /**
     * Draws horizontal guide lines across the plot at mean − SD (blue),
     * mean (red) and mean + SD (blue) of the y values.
     */
    drawGridHorizontal(g: SvgGroup, scale: LinearScale, width: number, stats: Stats) {
        hLine(g, 0, width, scale(stats.average - stats.std), 'blue', 0.5);
        hLine(g, 0, width, scale(stats.average), 'red', 0.5);
        hLine(g, 0, width, scale(stats.average + stats.std), 'blue', 0.5);
    }

    /**
     * Draws the SD line: a dashed green segment through the point of averages
     * reaching 1.5 SD to either side along both axes.
     *
     * @param width unused; kept so that existing callers keep working
     */
    drawSDLine(g: SvgGroup, xScale: LinearScale, yScale: LinearScale, width: number, statsX: Stats, statsY: Stats) {
        // NOTE(review): CMD is given both scales, so `line` presumably takes
        // data coordinates and converts them itself - confirm in './cmd'.
        const cmd = new CMD(xScale, yScale);
        const t = 1.5;
        const [xMean, xSD] = [statsX.average, statsX.std];
        const [yMean, ySD] = [statsY.average, statsY.std];
        cmd.line(g, xMean - t * xSD, yMean - t * ySD, xMean + t * xSD, yMean + t * ySD, 'green', 2)
            .attr('stroke-dasharray', '10,10');
    }

    /**
     * Draws vertical guide lines across the plot at mean − SD (blue),
     * mean (red) and mean + SD (blue) of the x values.
     */
    drawGridVertical(g: SvgGroup, scale: LinearScale, height: number, stats: Stats) {
        vLine(g, scale(stats.average - stats.std), 0, height, 'blue', 0.5);
        vLine(g, scale(stats.average), 0, height, 'red', 0.5);
        vLine(g, scale(stats.average + stats.std), 0, height, 'blue', 0.5);
    }

    /**
     * Renders the complete visualization into `svg`, replacing its previous
     * content. With fewer than two data points only an explanatory message is
     * shown.
     *
     * On success the scales and the main plot group are stored in `_xScale`,
     * `_yScale` and `_mainGraphG`.
     *
     * @param svg       root SVG selection
     * @param data      points to plot
     * @param statsX    statistics of the x values
     * @param statsY    statistics of the y values
     * @param sameScale whether both axes should use the same pixels-per-unit ratio
     */
    plot(svg: SvgRoot, data: Coordinate[], statsX: Stats, statsY: Stats, sameScale: boolean) {

        if (!this.isDataSufficient(data)) {
            this.showMessageWhenNoSufficientData(svg);
            return;
        }

        // ---------------------------------------------------------------------
        // clear all
        // ---------------------------------------------------------------------

        this.clearContent(svg);

        // ---------------------------------------------------------------------
        // topology
        // ---------------------------------------------------------------------

        const margin = this.setupMargins();

        const { width, height, xShift, yShift, xDomainL, xDomainR, yDomainL, yDomainR } =
            this.prepareTopology(margin, statsX, statsY, sameScale);

        const xScale = this.buildMainPlotScaleX(xDomainL, xDomainR, width);
        const yScale = this.buildMainPlotScaleY(yDomainL, yDomainR, height);
        this._xScale = xScale;
        this._yScale = yScale;

        // ---------------------------------------------------------------------
        // main plot
        // ---------------------------------------------------------------------

        const mainGraph = this.buildMainPlotGroup(svg, margin, xShift, yShift);
        this.drawScatterPlot(mainGraph, data, xScale, yScale);
        this._mainGraphG = mainGraph;

        // ---------------------------------------------------------------------
        // projection bands
        // ---------------------------------------------------------------------

        const { tickLength, tickThickness, tickColor } = this.getProjectionStyle();

        const gProjX = this.buildProjectionGroupX(svg, margin, height, xShift, yShift);
        this.addProjectionX(gProjX, statsX, { scale: xScale, tickLength, tickThickness, tickColor });

        const gProjY = this.buildProjectionGroupY(svg, margin, width, xShift, yShift);
        this.addProjectionY(gProjY, statsY, { scale: yScale, tickLength, tickThickness, tickColor });

        // ---------------------------------------------------------------------
        // axes
        // ---------------------------------------------------------------------

        this.buildAxisX(mainGraph, xScale, height);
        this.buildAxisY(mainGraph, yScale);

        // ---------------------------------------------------------------------
        // point of averages
        // ---------------------------------------------------------------------

        this.drawPointOfAverages(mainGraph, xScale(statsX.average), yScale(statsY.average), tickLength);

        // ---------------------------------------------------------------------
        // grid
        // ---------------------------------------------------------------------

        this.drawGridHorizontal(mainGraph, yScale, width, statsY);
        this.drawGridVertical(mainGraph, xScale, height, statsX);

        // ---------------------------------------------------------------------
        // SD line
        // ---------------------------------------------------------------------

        this.drawSDLine(mainGraph, xScale, yScale, width, statsX, statsY);
    }

    // ---------------------------------------------------------------------
    // x-projection
    // ---------------------------------------------------------------------
    /**
     * Fills the horizontal projection band: one tick per x value, thin marks
     * at mean − SD, mean and mean + SD, and their labels above the band.
     *
     * @param g     group of the x projection band
     * @param stats statistics of the x values (`average`, `std`, `orderedSample` are read)
     * @param spec  scale and tick style
     */
    addProjectionX(g: SvgGroup, stats: Stats, spec: IProjectionBandSpec) {

        const mu: number = stats.average;
        const sd: number = stats.std;
        const data: number[] = stats.orderedSample;

        const { scale, tickLength, tickThickness, tickColor } = spec;

        // One vertical tick per value, horizontally centred on the value.
        g.selectAll('rect')
            .data(data)
            .enter().append('rect')
            .attr('x', d => scale(d) - tickThickness / 2)
            .attr('y', 0)
            .attr('width', tickThickness)
            .attr('height', tickLength)
            .attr('fill', tickColor);

        const drawMarkX = (x: number, color: string) => {
            g.append('line')
                .attr('x1', scale(x))
                .attr('y1', -2)
                .attr('x2', scale(x))
                .attr('y2', 18)
                .attr('stroke', color)
                .attr('stroke-width', 0.5);
        };

        // Highlight mean and mean ± SD on the horizontal band.
        drawMarkX(mu - sd, 'blue');
        drawMarkX(mu, 'red');
        drawMarkX(mu + sd, 'blue');

        // Labels for mean and standard deviation.
        // NOTE(review): `sub` presumably attaches the given letter as a
        // subscript to the label - confirm in './subscript'.
        g.append('text')
            .attr('x', scale(mu - sd) - 14)
            .attr('y', -5)
            .text('x̄ - σ')
            .attr("fill", "blue")
            .call(sub, 'x');

        g.append('text')
            .attr('x', scale(mu) - 4)
            .attr('y', -5)
            .text('x̄')
            .attr('fill', 'red');

        g.append('text')
            .attr('x', scale(mu + sd) - 16)
            .attr('y', -5)
            .text('x̄ + σ')
            .attr("fill", "blue")
            .call(sub, 'x');
    }

    // ---------------------------------------------------------------------
    // y-projection
    // ---------------------------------------------------------------------
    /**
     * Fills the vertical projection band: one tick per y value, thin marks
     * at mean, mean − SD and mean + SD, and their labels left of the band.
     *
     * @param g     group of the y projection band
     * @param stats statistics of the y values (`average`, `std`, `orderedSample` are read)
     * @param spec  scale and tick style
     */
    addProjectionY(g: SvgGroup, stats: Stats, spec: IProjectionBandSpec) {

        const mu: number = stats.average;
        const sd: number = stats.std;
        const data: number[] = stats.orderedSample;

        const { scale, tickLength, tickThickness, tickColor } = spec;

        // One horizontal tick per value, vertically centred on the value.
        g.selectAll('rect')
            .data(data)
            .enter().append('rect')
            .attr('x', 0)
            .attr('y', d => scale(d) - tickThickness / 2)
            .attr('width', tickLength)
            .attr('height', tickThickness)
            .attr('fill', tickColor);

        const drawMarkY = (y: number, color: string) => {
            g.append('line')
                .attr('x1', -2)
                .attr('y1', scale(y))
                .attr('x2', 20)
                .attr('y2', scale(y))
                .attr('stroke', color)
                .attr('stroke-width', 0.5);
        };

        // Highlight mean and mean ± SD on the vertical band.
        drawMarkY(mu, 'red');
        drawMarkY(mu - sd, 'blue');
        drawMarkY(mu + sd, 'blue');

        g.append('text')
            .attr('x', -44)
            .attr('y', scale(mu - sd) + 3)
            .text('ȳ - σ')
            .attr("fill", "blue")
            .call(sub, 'y');

        g.append('text')
            .attr('x', -20)
            .attr('y', scale(mu) + 3)
            .text('ȳ')
            .attr('fill', 'red');

        // This is a y-band label, so its subscript is 'y' (it used to be 'x').
        g.append('text')
            .attr('x', -44)
            .attr('y', scale(mu + sd) + 3)
            .text('ȳ + σ')
            .attr("fill", "blue")
            .call(sub, 'y');
    }
}

/**
 * Scale and tick styling shared by the x and y projection bands
 * (consumed by `Plotter.addProjectionX` and `Plotter.addProjectionY`).
 */
interface IProjectionBandSpec {
    /** Maps a data value to its pixel position along the band. */
    scale: d3.ScaleLinear<number, number>;
    /** Size of a tick across the band (rect height in the x band, rect width in the y band). */
    tickLength: number;
    /** Size of a tick along the band; each tick is centred on its value. */
    tickThickness: number;
    /** Fill colour of the ticks. */
    tickColor: string;
}