import {
  Cell,
  CellContext,
  ColumnDef,
  flexRender,
  functionalUpdate,
  getCoreRowModel,
  getExpandedRowModel,
  getSortedRowModel,
  Header as RTHeader,
  Row,
  SortingFnOption,
  SortingState,
  useReactTable,
} from "@tanstack/react-table";
import sprinkles from "css/sprinkles.css";
import React from "react";
import {
  CSSProperties,
  useCallback,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { useVirtual } from "react-virtual";

import { Checkbox, Icon, Tooltip } from "..";
import { IconName } from "../icon/Icon";
import * as styles from "./Table.css";

/** Re-export of TanStack Table's `CellContext` type for consumers of this module. */
export type { CellContext };

/**
 * Sort direction as exchanged with callers of `Table`
 * (`defaultSortDirection`, `manualSortDirection`, `handleManualSort`).
 */
type SortDirectionType =
  | "ASC"
  | "DESC";

/**
 * Public column configuration accepted by `Table`.
 *
 * `Data` is the row type; `id` must be one of its string keys and is used as
 * the TanStack accessor key for the column.
 */
export interface Header<Data> {
  /** Key of `Data` this column reads; doubles as the column id. */
  id: keyof Data & string;
  /** Plain-text header, shown unless `customHeader` is provided. */
  label: string;
  /** Element rendered in the header cell instead of `label`. */
  customHeader?: React.ReactElement;
  /** Custom cell content; receives the row's original data object. */
  customCellRenderer?: (rowData: Data) => React.ReactElement | string;

  /** Initial column size (applied as a pixel width via CSS variables). */
  width?: number;
  /** Minimum column size while resizing. */
  minWidth?: number;
  /** Whether the column can be resized by dragging. Defaults to true. */
  resizable?: boolean;

  /** Whether clicking the header sorts by this column. Defaults to true. */
  sortable?: boolean;
  /** Client-side sorting function. Defaults to "alphanumeric". */
  sortingFn?: SortingFnOption<Data>;
}

/** Props for `Table`. `Data` is the type of a single row. */
interface Props<Data> {
  /**
   * Column configuration. Built-in selection and expand-toggle columns, when
   * enabled, are inserted before these.
   */
  columns: Header<Data>[];
  /** Rows currently loaded. */
  rows: Data[];
  /**
   * Total number of rows available, including ones not loaded yet. Scrolling
   * stops requesting more once `rows.length` reaches it, and the empty state
   * is shown when it is 0.
   */
  totalNumRows: number;
  /** Stable id for a row; used as the TanStack row id and in `checkedRowIds`. */
  getRowId: (data: Data) => string;
  /**
   * Whether a row can be expanded. When omitted, every row is expandable as
   * soon as `renderSubRow` or `getChildRows` is provided.
   */
  getRowCanExpand?: (data: Row<Data>) => boolean;
  /**
   * Content shown when `totalNumRows` is 0 and nothing is loading.
   * The title falls back to "No rows" when this is omitted.
   */
  emptyState?: {
    title: string;
    subtitle?: string;
    icon?: IconName;
  };

  /** Whether a row is highlighted; highlighted rows use blue by default. */
  getRowHighlighted?: (data: Data) => boolean;
  /** Explicit highlight color for a row; takes precedence over `getRowHighlighted`. */
  getRowHighlightColor?: (
    data: Data
  ) => "blue" | "red" | "orange" | "green" | undefined;

  /**
   * Infinite scroll / pagination hook. Called without an argument when the
   * user scrolls within 50px of the bottom and fewer than `totalNumRows` rows
   * are loaded, and with a row id when a collapsed row's expand chevron is
   * clicked.
   */
  onLoadMoreRows?: (id?: string) => Promise<void>;
  /** True while rows are loading: shows a loading row and suppresses scroll-triggered loads. */
  loadingRows?: boolean;

  /** Column sorted on first render (client-side sorting only). */
  defaultSortBy?: keyof Data & string;
  /** Direction for `defaultSortBy`; anything other than "DESC" sorts ascending. */
  defaultSortDirection?: SortDirectionType;
  /** Current server-side sort; only read when `handleManualSort` is provided. */
  manualSortDirection?: {
    sortBy: string;
    sortDirection: SortDirectionType;
  };
  /**
   * Enables server-side sorting. Invoked whenever the user changes the sort;
   * receives an empty `sortBy` and no direction when sorting is cleared.
   */
  handleManualSort?: (
    sortBy: string,
    sortDirection?: SortDirectionType
  ) => void;
  /**
   * When enableSortingRemoval is true, sorting a column will cycle from ASC -> DESC -> none.
   * When enableSortingRemoval is false, sorting a column will cycle from ASC -> DESC.
   * Defaults to false.
   */
  enableSortingRemoval?: boolean;

  /** Ids (as produced by `getRowId`) whose checkbox is checked. */
  checkedRowIds?: Set<string>;
  /**
   * Providing this adds a checkbox column. Called with the toggled row's id
   * (as a one-element array) and its new checked state.
   */
  onCheckedRowsChange?: (ids: string[], checked: boolean) => void;
  /** Checked state of the header "select all" checkbox. */
  selectAllChecked?: boolean;
  /** Providing this shows a "select all" checkbox in the header while there are rows. */
  onSelectAll?: (checked: boolean) => void;
  /** Returns a non-empty reason to disable a row's checkbox; the reason is shown as a tooltip. */
  getCheckboxDisabledReason?: (row: Data) => string | undefined;

  /** Makes rows clickable; receives the row data and the click event. */
  onRowClick?: (
    data: Data,
    e: React.MouseEvent<HTMLTableRowElement, MouseEvent>
  ) => void;

  /** Returns nested child rows for a row (used as TanStack `getSubRows`). */
  getChildRows?: (data: Data) => Data[] | undefined;
  /** Renders custom full-width content beneath an expanded row. */
  renderSubRow?: (data: Data) => React.ReactElement;

  /** Draw borders around cells. */
  borderedCells?: boolean;
  /** Hide the header row entirely. */
  hideHeader?: boolean;
}

/**
 * Virtualized data table built on TanStack Table and react-virtual.
 *
 * Features, each switched on by the corresponding props:
 * - client-side sorting, or server-side sorting when `handleManualSort` is set;
 * - infinite scrolling via `onLoadMoreRows` / `totalNumRows`;
 * - a checkbox selection column when `onCheckedRowsChange` is set;
 * - expandable rows via `getChildRows` (nested rows) and/or `renderSubRow`
 *   (custom content beneath the row);
 * - resizable columns whose sizes are exposed as CSS variables.
 */
const Table = <Data,>(props: Props<Data>) => {
  const {
    loadingRows,
    rows: propRows,
    onLoadMoreRows,
    totalNumRows,
    onCheckedRowsChange,
    getCheckboxDisabledReason,
    renderSubRow,
    getChildRows,
    enableSortingRemoval = false,
    hideHeader,
  } = props;

  // Local sort state seeded from defaultSortBy/defaultSortDirection. It only
  // drives what the table displays in client-side mode.
  const [sorting, setSorting] = useState<SortingState>(
    props.defaultSortBy
      ? [
          {
            id: props.defaultSortBy,
            desc: props.defaultSortDirection === "DESC",
          },
        ]
      : []
  );

  // The sort state actually handed to the table. In manual (server-side) mode
  // it mirrors `manualSortDirection`, or no sort when that prop is unset.
  let sortingState = sorting;
  if (props.handleManualSort) {
    sortingState = [];
    if (props.manualSortDirection) {
      sortingState = [
        {
          id: props.manualSortDirection.sortBy,
          desc: props.manualSortDirection.sortDirection === "DESC",
        },
      ];
    }
  }

  // Map the public column config onto TanStack column defs, then prepend the
  // optional built-in expand-toggle and selection columns (selection ends up first).
  const columnData: ColumnDef<Data>[] = useMemo(() => {
    const columns = props.columns.map((col) => {
      const colData: ColumnDef<Data> = {
        accessorKey: col.id,
        header: col.customHeader ? () => col.customHeader : col.label,
        size: col.width,
        minSize: col.minWidth,
        enableSorting: col.sortable,
        sortingFn: col.sortingFn || "alphanumeric",
        enableResizing: col.resizable ?? true,
      };
      if (col.customCellRenderer) {
        const renderCell = col.customCellRenderer;
        colData.cell = ({ row }) => renderCell(row.original);
      }
      return colData;
    });

    if (renderSubRow || getChildRows) {
      columns.unshift({
        accessorKey: "",
        enableSorting: false,
        id: "expand-toggle",
        header: () => null,
        cell: ({ row }) => {
          if (!row.getCanExpand()) {
            return null;
          }
          const icon = row.getIsExpanded() ? (
            <Icon name="chevron-down" size="sm" />
          ) : (
            <Icon
              name="chevron-right"
              size="sm"
              onClick={() => {
                // Expanding a collapsed row also asks the parent to load data
                // for that row (presumably its children — confirm with callers).
                void onLoadMoreRows?.(row.id);
              }}
            />
          );
          return (
            <div
              className={sprinkles({ cursor: "pointer" })}
              onClick={(e) => {
                // Prevent row click from firing
                e.stopPropagation();
                row.getToggleExpandedHandler()();
              }}
            >
              {icon}
            </div>
          );
        },
        size: 24,
        minSize: 24,
        maxSize: 24,
        enableResizing: false,
      });
    }

    if (onCheckedRowsChange) {
      columns.unshift({
        accessorKey: "",
        enableSorting: false,
        id: "select",
        header: () => {
          if (props.onSelectAll && props.rows.length > 0) {
            return (
              <Checkbox
                checked={Boolean(props.selectAllChecked)}
                onChange={props.onSelectAll}
              />
            );
          }
          return null;
        },
        cell: (cellContext) => {
          const checked = props.checkedRowIds?.has(cellContext.row.id);
          const disabledReason = getCheckboxDisabledReason
            ? getCheckboxDisabledReason(cellContext.row.original)
            : undefined;

          return (
            <Tooltip tooltipText={disabledReason}>
              <Checkbox
                disabled={Boolean(disabledReason)}
                checked={Boolean(checked)}
                onChange={(checked) =>
                  onCheckedRowsChange([cellContext.row.id], checked)
                }
              />
            </Tooltip>
          );
        },
        size: 26,
        minSize: 26,
        maxSize: 26,
        enableResizing: false,
      });
    }

    return columns;
  }, [
    props.columns,
    onCheckedRowsChange,
    props.checkedRowIds,
    props.selectAllChecked,
    props.onSelectAll,
    getCheckboxDisabledReason,
    renderSubRow,
    getChildRows,
    onLoadMoreRows,
    props.rows.length,
  ]);

  // Without an explicit predicate, any row is expandable once an expansion
  // mechanism (sub-row renderer or child rows) is configured.
  const handleRowCanExpand = (row: Row<Data>) => {
    if (props.getRowCanExpand) {
      return props.getRowCanExpand(row);
    }
    return Boolean(props.renderSubRow || props.getChildRows);
  };

  const tableData = useReactTable({
    data: props.rows,
    columns: columnData,
    getCoreRowModel: getCoreRowModel(),
    state: {
      sorting: sortingState,
    },
    getRowId: props.getRowId,
    manualSorting: Boolean(props.handleManualSort),
    enableSortingRemoval: enableSortingRemoval,
    columnResizeMode: "onChange",
    enableColumnResizing: true,
    onSortingChange: (sortingUpdater) => {
      // Resolve the updater against the state the table is displaying. In
      // manual mode that is the prop-derived `sortingState`; the local
      // `sorting` can be stale there, which would make toggles (e.g.
      // DESC -> none) compute from the wrong previous value.
      const newSortVal = functionalUpdate(sortingUpdater, sortingState);
      setSorting(newSortVal);
      if (props.handleManualSort) {
        if (newSortVal.length === 0) {
          props.handleManualSort("");
          return;
        }
        props.handleManualSort(
          newSortVal[0].id,
          newSortVal[0].desc ? "DESC" : "ASC"
        );
      }
    },
    getSortedRowModel: getSortedRowModel(),
    getRowCanExpand: handleRowCanExpand,
    getSubRows: props.getChildRows,
    getExpandedRowModel: getExpandedRowModel(),
  });

  /**
   * Instead of calling `column.getSize()` on every render for every header
   * and especially every data cell (very expensive),
   * we will calculate all column sizes at once at the root table level in a useMemo
   * and pass the column sizes down as CSS variables to the <table> element.
   */
  const tableState = tableData.getState();
  const columnSizeVars = useMemo(() => {
    const colSizes: { [key: string]: number } = {};
    for (const header of tableData.getFlatHeaders()) {
      colSizes[`--header-${header.id}-size`] = header.getSize();
      colSizes[`--col-${header.column.id}-size`] = header.column.getSize();
    }
    return colSizes;
    // Note: we include columnData here so that we update this memo when expandRow properties are changed
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [tableState.columnSizingInfo, tableState.columnSizing, columnData]);

  const tableContainerRef = useRef<HTMLDivElement>(null);

  const fetchMoreOnBottomReached = useCallback(
    (containerRefElement?: HTMLDivElement | null) => {
      if (containerRefElement && onLoadMoreRows) {
        const { scrollHeight, scrollTop, clientHeight } = containerRefElement;
        // Once the user has scrolled within 50px of the bottom of the table, fetch more data if there is any
        if (
          scrollHeight - scrollTop - clientHeight < 50 &&
          !loadingRows &&
          propRows.length < totalNumRows
        ) {
          void onLoadMoreRows();
        }
      }
    },
    [onLoadMoreRows, loadingRows, propRows, totalNumRows]
  );

  // Check on mount and after a fetch to see if the table is already scrolled to the bottom and immediately needs to fetch more data
  useEffect(() => {
    fetchMoreOnBottomReached(tableContainerRef.current);
  }, [fetchMoreOnBottomReached]);

  const { rows } = tableData.getRowModel();

  const rowVirtualizer = useVirtual({
    parentRef: tableContainerRef,
    size: rows.length,
    overscan: 10,
  });
  const { virtualItems: virtualRows, totalSize } = rowVirtualizer;
  // Add padding for virtualized rows so that scrollbar reflects correct height.
  const paddingTop = virtualRows.length > 0 ? virtualRows?.[0]?.start || 0 : 0;
  const paddingBottom =
    virtualRows.length > 0
      ? totalSize - (virtualRows?.[virtualRows.length - 1]?.end || 0)
      : 0;

  const renderHeader = useCallback(
    (header: RTHeader<Data, unknown>) => {
      if (header.isPlaceholder || hideHeader) {
        return null;
      }

      // getIsSorted() yields "asc" | "desc" | false; narrow explicitly instead
      // of indexing a lookup object with its stringified value.
      const sortDirection = header.column.getIsSorted();
      let sortIcon: React.ReactElement | null = null;
      if (sortDirection === "asc") {
        sortIcon = <Icon name="arrow-up" color="gray600" size="xs" />;
      } else if (sortDirection === "desc") {
        sortIcon = <Icon name="arrow-down" color="gray600" size="xs" />;
      }

      return (
        <div
          className={styles.headerCell({
            sortable: header.column.getCanSort(),
          })}
          onClick={header.column.getToggleSortingHandler()}
        >
          {flexRender(header.column.columnDef.header, header.getContext())}
          {sortIcon}
        </div>
      );
    },
    [hideHeader]
  );

  // Use this to inject any styles that need to be calculated.
  // For any basic or boolean logic styling, add to Table.css "cell" class.
  // Every cell gets its column width; for nested rows the expand-toggle cell
  // and the cell right after it are additionally indented by depth.
  const getExtraCellStyle = useCallback(
    (
      row: Row<Data>,
      cell: Cell<Data, unknown>,
      cellIndex: number
    ): CSSProperties | undefined => {
      const base = {
        width: `calc(var(--col-${cell.column.id}-size) * 1px)`,
      };
      if (row.depth === 0) {
        return base;
      }
      // The selection column, when present, is always first, pushing the
      // expand toggle to index 1.
      const hasCheckbox = Boolean(onCheckedRowsChange);
      const expandToggleIndex = hasCheckbox ? 1 : 0;
      if (cellIndex === expandToggleIndex) {
        return {
          ...base,
          paddingLeft: row.depth * 24,
          overflowX: "visible",
        };
      } else if (cellIndex === expandToggleIndex + 1) {
        return {
          ...base,
          paddingLeft: row.depth * 24,
        };
      }
      // Remaining cells of nested rows still need their column width.
      return base;
    },
    [onCheckedRowsChange]
  );

  return (
    <div
      className={styles.container}
      ref={tableContainerRef}
      onScroll={(e) => fetchMoreOnBottomReached(e.currentTarget)}
    >
      <table className={styles.table} style={{ ...columnSizeVars }}>
        {!hideHeader && (
          <thead className={styles.header}>
            {tableData.getHeaderGroups().map((headerGroup) => (
              <tr key={headerGroup.id}>
                {headerGroup.headers.map((header) => {
                  return (
                    <th
                      key={header.id}
                      className={styles.columnHeader}
                      style={{
                        width: `calc(var(--header-${header.id}-size) * 1px)`,
                        maxWidth: `calc(var(--header-${header.id}-size) * 1px)`,
                      }}
                    >
                      {renderHeader(header)}

                      {header.column.getCanResize() ? (
                        <div
                          onDoubleClick={() => header.column.resetSize()}
                          onMouseDown={header.getResizeHandler()}
                          onTouchStart={header.getResizeHandler()}
                          className={styles.headerResizer({
                            isResizing: header.column.getIsResizing(),
                          })}
                        />
                      ) : null}
                    </th>
                  );
                })}
              </tr>
            ))}
          </thead>
        )}
        <tbody>
          {props.totalNumRows === 0 && !props.loadingRows && (
            <tr>
              <td colSpan={columnData.length}>
                <div className={styles.emptyStateContainer}>
                  <div>
                    {props.emptyState?.icon && (
                      <Icon name={props.emptyState.icon} size="lg" />
                    )}
                  </div>
                  <div className={styles.emptyStateTitle}>
                    {props.emptyState?.title ?? "No rows"}
                  </div>
                  <div>{props.emptyState?.subtitle}</div>
                </div>
              </td>
            </tr>
          )}
          {paddingTop > 0 && (
            <tr>
              <td style={{ height: `${paddingTop}px` }} />
            </tr>
          )}
          {virtualRows.map((virtualRow) => {
            const row = rows[virtualRow.index];
            if (!row) {
              // The virtualizer can briefly report an index past the current
              // row model (e.g. while rows shrink); render nothing for it.
              return null;
            }

            const isHighlighted = props.getRowHighlighted?.(row.original);
            // highlight color is blue by default
            const highlightColor =
              props.getRowHighlightColor?.(row.original) ??
              (isHighlighted ? "blue" : undefined);

            // The key must be on the outermost element returned from map;
            // a bare `<>` fragment cannot carry one.
            return (
              <React.Fragment key={row.id}>
                <tr
                  key={row.id}
                  className={styles.row({
                    clickable: Boolean(props.onRowClick),
                    highlighted: highlightColor,
                    // Use row index to render alternating row colors instead of css property nth-child,
                    // because the virtualization causes the rows to switch between odd and even
                    gray: virtualRow.index % 2 === 1,
                  })}
                  onClick={(e) => props.onRowClick?.(row.original, e)}
                  ref={(e) => virtualRow.measureRef(e)}
                >
                  {row.getVisibleCells().map((cell, i) => {
                    return (
                      <td
                        key={cell.id}
                        className={styles.cell({
                          border: props.borderedCells,
                        })}
                        style={getExtraCellStyle(row, cell, i)}
                      >
                        {flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext()
                        )}
                      </td>
                    );
                  })}
                </tr>
                {renderSubRow && row.getIsExpanded() && (
                  <tr className={styles.row({})} key={`${row.id}-subrow`}>
                    <td colSpan={row.getVisibleCells().length}>
                      {renderSubRow(row.original)}
                    </td>
                  </tr>
                )}
              </React.Fragment>
            );
          })}
          {props.loadingRows && (
            <tr>
              {/* Span the full width, like the empty-state row. */}
              <td className={styles.loading} colSpan={columnData.length}>
                Loading...
              </td>
            </tr>
          )}
          {paddingBottom > 0 && (
            <tr>
              <td style={{ height: `${paddingBottom}px` }} />
            </tr>
          )}
        </tbody>
      </table>
    </div>
  );
};

export default Table;
