import "./column-resizer.scss";
import React, {RefObject, useCallback, useEffect, useMemo, useRef, useState} from "react";
import {clamp, Dict, tryParseJson} from "./util";
import {DraggableCore} from "react-draggable";

// Smallest width, in px, that a column can be shrunk to while dragging a resize handle.
const MIN_COL_WIDTH = 20;
// Weight assigned to a column that has no saved weight; rendered widths are proportional to weights.
const BASE_COL_WEIGHT = 100;

/** Options accepted by `useColumnResize`. */
export interface ColumnResizeOptions {
    /** When set, column weights are saved to (and restored from) localStorage under "column-resizer." + tableId. */
    tableId?: string;
    /** When true, widths are emitted as percentages; otherwise as pixels derived from the table's rendered width. */
    relative?: boolean;
}

/** Value returned by `useColumnResize`. */
export interface ColumnResizeInstance {
    /** CSS width per column id ("…%" or "…px"). In pixel mode this is empty while the table's width is unknown. */
    columnWidths: Dict<string>;
    /** Renders the drag handle for the given column, or nothing while the table's width is unknown. */
    renderResizeHandle: (columnId: string) => React.ReactNode;
}

/**
 * Makes the columns of a table resizable by dragging handles rendered inside the header cells.
 *
 * Column sizes are tracked as relative weights (BASE_COL_WEIGHT each by default) and turned into CSS widths:
 * percentages when `options.relative` is set, otherwise pixels computed from the table's rendered width.
 * When `options.tableId` is set, weights are restored from and saved to localStorage.
 *
 * @param ref       ref to the table element; its width is read during render.
 * @param columnIds ids of the columns in header order. Compared by value, so a new but equal array is fine.
 * @param options   persistence key and width mode, see ColumnResizeOptions.
 * @returns the CSS width per column and a renderer for each column's drag handle.
 */
export function useColumnResize(ref: RefObject<HTMLTableElement>, columnIds: string[], options?: ColumnResizeOptions): ColumnResizeInstance {
    const tableId = options?.tableId;
    const relative = options?.relative;

    // Lazy initializer: localStorage is read once on mount rather than on every render.
    const [columnWeights, setColumnWeights] = useState(() => getInitialColumnWeights(columnIds, tableId));

    // Reset weights whenever the column ids or the table id change. Ids are compared by value so a caller that
    // passes a fresh-but-equal array on each render does not trigger an endless reset/re-render loop.
    // This also runs on mount; the re-render it causes is the first one in which `ref.current` is attached,
    // which is what lets `tableWidth` below become available.
    const columnKey = JSON.stringify(columnIds);
    useEffect(() => setColumnWeights(getInitialColumnWeights(columnIds, tableId)), [columnKey, tableId]);

    // In the render between a column change and the reset effect, `columnWeights` may lack new ids;
    // fall back to BASE_COL_WEIGHT for those instead of producing NaN widths.
    const weightOf = (id: string) => columnWeights[id] ?? BASE_COL_WEIGHT;
    const totalWeight = columnIds.reduce((sum, id) => sum + weightOf(id), 0);

    // Read during render: undefined on the first render, and only refreshed when this component re-renders.
    const tableWidth = ref.current?.getBoundingClientRect().width;
    const columnWidths: Dict<string> = useMemo(() => {
        const widths: Dict<string> = {};
        if (relative) {
            columnIds.forEach(id => widths[id] = (100 * weightOf(id) / totalWeight) + "%");
        } else {
            if (!tableWidth) return {};
            columnIds.forEach(id => widths[id] = (weightOf(id) * tableWidth / totalWeight) + "px");
        }
        return widths;
    }, [columnWeights, columnIds, relative, tableWidth]);

    const onResize = useCallback((newColumnWidths: Dict<number>) => {
        // Convert pixel widths back into weights. Normalising by the sum of the new widths (rather than the
        // table's bounding width, which may include borders/spacing or be stale) keeps the total weight constant.
        const widthOf = (id: string) => newColumnWidths[id] || 0;
        const totalWidth = columnIds.reduce((sum, id) => sum + widthOf(id), 0);
        if (!(totalWidth > 0)) return;

        const newColumnWeights: Dict<number> = {};
        columnIds.forEach(id => newColumnWeights[id] = widthOf(id) * totalWeight / totalWidth);
        setColumnWeights(newColumnWeights);

        if (tableId) {
            saveColumnWeights(newColumnWeights, tableId);
        }
    }, [columnIds, totalWeight, tableId]);

    // Handles are only offered once the table has a measurable width.
    const renderResizeHandle = (id: string) => tableWidth ? <ResizeHandle columnId={id} columnIds={columnIds} onResize={onResize}/> : undefined;

    return {columnWidths, renderResizeHandle};
}

/** Props of the drag handle placed inside a header cell. */
type ResizeHandleProps = {
    /** Column resized by this handle. */
    columnId: string;
    /** Every column id, in the same left-to-right order as the header row's cells. */
    columnIds: string[];
    /** Receives the final pixel width of every column, keyed by id, when a drag finishes. */
    onResize: (newWidths: Dict<number>) => void;
};

/** Bookkeeping for a single column during a drag. */
interface ColumnResize {
    /** Header cell whose inline width is adjusted while dragging. */
    element: HTMLTableCellElement;
    /** Inline `style.width` captured before the drag, restored when it ends. */
    originalWidthStyle: string;
    /** Rendered width in px when the drag started. */
    originalWidth: number;
    /** Width in px computed for the most recent pointer position. */
    currentWidth: number;
}

/** State of a drag that is in progress. */
interface ResizeState {
    /** Pointer x coordinate at the start of the drag. */
    startX: number;
    /** Position of the dragged column within `columnIds`. */
    resizedIndex: number;
    /** Per-column bookkeeping, keyed by column id. */
    columns: Dict<ColumnResize>;
}

/**
 * Drag handle that resizes one column of a table header.
 *
 * It must be rendered directly inside a header cell: its grandparent is taken to be the header row, whose
 * children are the header cells in `columnIds` order. While dragging, widths are written to the cells'
 * inline styles for immediate feedback; when the drag ends those styles are restored and the final pixel
 * widths are reported via `props.onResize`.
 */
function ResizeHandle(props: ResizeHandleProps) {
    const resizeRef = useRef<ResizeState | null>(null);

    const ref = useRef<HTMLDivElement>(null);
    // handle -> header cell -> header row (the element whose children are the cells being resized)
    const getHeaderRow = () => ref.current?.parentElement?.parentElement ?? undefined;

    const onResizeStart = useCallback((x: number | undefined) => {
        const row = getHeaderRow();
        // returning false cancels the drag
        if (!row || x == null) return false;

        // widths are matched to columns by position, so there must be exactly one cell per column
        const cells = Array.from(row.children) as HTMLTableCellElement[];
        if (cells.length !== props.columnIds.length) return false;

        const resizedIndex = props.columnIds.indexOf(props.columnId);
        if (resizedIndex < 0) return false;

        const columns: Dict<ColumnResize> = {};
        props.columnIds.forEach((colId, i) => {
            const element = cells[i];
            const width = element.getBoundingClientRect().width;
            columns[colId] = {element, originalWidthStyle: element.style.width, originalWidth: width, currentWidth: width};
        });

        resizeRef.current = {startX: x, resizedIndex, columns};
    }, [props.columnId, props.columnIds]);

    const onResize = useCallback((x: number) => {
        const resize = resizeRef.current;
        if (!resize) return;
        const {columns, resizedIndex} = resize;
        const resizeX = Math.floor(x - resize.startX);
        const rightColIds = props.columnIds.slice(resizedIndex + 1);
        // how much a column can shrink before reaching MIN_COL_WIDTH; never negative, even for narrower columns
        const availableWidth = (id: string) => Math.max(0, columns[id].originalWidth - MIN_COL_WIDTH);

        // Every move is computed from the widths at drag start, so columns changed by an earlier move in the
        // opposite direction are restored.
        Object.values(columns).forEach(c => c.currentWidth = c.originalWidth);

        if (resizeX > 0) {
            // Grow the resized column by taking width from the columns to its right, proportionally to how much
            // each can give before reaching MIN_COL_WIDTH.
            const rightAvailable = rightColIds.map(availableWidth);
            const totalAvailable = rightAvailable.reduce((sum, w) => sum + w, 0);
            const widthToTake = clamp(resizeX, 0, totalAvailable);
            // With nothing to take (last column, or all right columns at minimum) avoid 0 / 0 = NaN.
            const takeRatio = totalAvailable > 0 ? widthToTake / totalAvailable : 0;

            rightColIds.forEach((id, i) => columns[id].currentWidth = columns[id].originalWidth - rightAvailable[i] * takeRatio);
            columns[props.columnId].currentWidth = columns[props.columnId].originalWidth + widthToTake;
        } else {
            // Shrink the resized column and, once it hits MIN_COL_WIDTH, its left neighbours one at a time.
            let takenWidth = 0;
            for (let i = resizedIndex; i >= 0; i--) {
                const id = props.columnIds[i];
                const widthToTake = clamp(-resizeX - takenWidth, 0, availableWidth(id));
                columns[id].currentWidth = columns[id].originalWidth - widthToTake;
                takenWidth += widthToTake;
            }

            // Give the taken width to the columns on the right, proportionally to their original widths.
            const rightTotalWidth = rightColIds.reduce((sum, id) => sum + columns[id].originalWidth, 0);
            if (rightTotalWidth > 0) {
                rightColIds.forEach(id => columns[id].currentWidth = columns[id].originalWidth + takenWidth * columns[id].originalWidth / rightTotalWidth);
            }
        }

        Object.values(columns).forEach(c => c.element.style.width = c.currentWidth + "px");
    }, [props.columnId, props.columnIds]);

    const onResizeEnd = useCallback((x: number | undefined) => {
        const resize = resizeRef.current;
        if (!resize) return;
        // apply the final pointer position if the event carries one; otherwise keep the last computed widths
        if (x != null) onResize(x);

        const newWidths: Dict<number> = {};
        props.columnIds.forEach(colId => newWidths[colId] = resize.columns[colId].currentWidth);

        // restore the inline styles; the owner applies the final widths itself
        Object.values(resize.columns).forEach(c => c.element.style.width = c.originalWidthStyle);

        resizeRef.current = null;
        props.onResize(newWidths);
    }, [onResize, props.onResize, props.columnIds]);

    return <DraggableCore onStart={e => onResizeStart(getClientX(e))}
                          onDrag={e => {
                              const x = getClientX(e);
                              if (x != null) onResize(x);
                          }}
                          onStop={e => onResizeEnd(getClientX(e))}
                          offsetParent={getHeaderRow()}>
        <div ref={ref} className="ResizeHandle" onClick={e => e.stopPropagation()}/>
    </DraggableCore>;
}

/** Horizontal viewport coordinate of a mouse or touch event, or undefined if the event carries none. */
function getClientX(e: unknown): number | undefined {
    const event = e as {clientX?: unknown, changedTouches?: ArrayLike<{clientX: number}>};
    if (typeof event.clientX === "number") return event.clientX;
    // touch events have no clientX; changedTouches still holds the finger on touchend, unlike `touches`
    const touches = event.changedTouches;
    return touches && touches.length > 0 ? touches[0].clientX : undefined;
}

/**
 * Returns the starting weight of every column in `columnIds`.
 *
 * When `tableId` is given, weights saved under "column-resizer." + tableId are reused for columns that
 * are still present; saved entries for other columns are ignored. Columns without a usable saved weight
 * (missing, non-numeric, non-finite, or not positive) get BASE_COL_WEIGHT.
 */
function getInitialColumnWeights(columnIds: string[], tableId?: string): Dict<number> {
    const saved = tableId ? loadSavedColumnWeights(tableId) : undefined;
    const columnWeights: Dict<number> = {};
    columnIds.forEach(id => {
        const weight = saved ? saved[id] : undefined;
        columnWeights[id] = typeof weight === "number" && Number.isFinite(weight) && weight > 0 ? weight : BASE_COL_WEIGHT;
    });
    return columnWeights;
}

/** Reads the saved weights of a table as an untrusted object; undefined if absent, unreadable, or not a plain object. */
function loadSavedColumnWeights(tableId: string): Dict<unknown> | undefined {
    let parsed: Dict<unknown> | undefined;
    try {
        parsed = tryParseJson<Dict<unknown> | undefined>(localStorage.getItem("column-resizer." + tableId), undefined);
    } catch (e) {
        // localStorage may be missing or may throw when storage is unavailable; fall back to default weights
        console.warn("column-resizer: could not read saved column widths for table " + tableId, e);
        return undefined;
    }
    // saved data is untrusted: anything other than a plain object (e.g. a number or array) is ignored
    return parsed && typeof parsed === "object" && !Array.isArray(parsed) ? parsed : undefined;
}

/**
 * Persists column weights under "column-resizer." + tableId.
 *
 * Best-effort: if localStorage is unavailable or full, a warning is logged and the weights are simply not
 * saved, so a failed write never breaks the resize interaction that triggered it.
 */
function saveColumnWeights(columnWeights: Dict<number>, tableId: string) {
    try {
        localStorage.setItem("column-resizer." + tableId, JSON.stringify(columnWeights));
    } catch (e) {
        console.warn("column-resizer: could not save column widths for table " + tableId, e);
    }
}