import React, {createRef, useMemo, useRef, useState} from "react";
import {scaleBand, scaleLinear, scaleOrdinal} from "@visx/scale";
import {Box, Divider, Grid, Paper, Typography, useTheme} from "@mui/material";
import {GridRows} from "@visx/grid";
import {Group} from "@visx/group";
import {AreaClosed, BarGroup as BarGroupVis, BarStack, Line, LinePath} from '@visx/shape';
import ResizeListener from "../../utils/ResizeListener";
import {AxisBottom, AxisLeft, AxisRight} from "@visx/axis";
import {curveMonotoneX, curveNatural} from '@visx/curve';
import {defaultStyles, useTooltip, useTooltipInPortal} from "@visx/tooltip";
import Timeline from "@mui/icons-material/Timeline";
import {Circle, StackedBarChart, StackedLineChart} from "@mui/icons-material";
import {BarChart} from "react-feather";
import {capitalizeSnakeCase} from "../../utils/string";

// Chart margins in px, reserved around the plot area.
// vMargin: vertical space (plot height = height - vMargin; the bottom axis and
// top offset are placed inside it). hMargin: horizontal space for the y axes
// (the left axis always, the right axis when scales are not shared).
const [vMargin, hMargin] = [80, 60];

const BlendedVis = ({
                      height,
                      data,
                      leftAxisSelectedItems,
                      leftAxisRendition,
                      rightAxisSelectedItems,
                      rightAxisRendition,
                      axisUseCommonYScale,
                      selectedBreakdownItem,
                      x0ScaleValueAccessor,
                      xAxisLabel,
                      yAxisLabel,
                      valueFormatters,
                      ...props
}) => {
  const theme = useTheme();
  const { tooltipOpen, tooltipLeft, tooltipTop, tooltipData, hideTooltip, showTooltip } = useTooltip();
  const { TooltipInPortal } = useTooltipInPortal({
    scroll: true,
  });
  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
  };
  const visContainerRef = useRef(null);
  const [visWidth, setVisWidth] = useState(450);
  const hasData = data && data.length > 0;
  const [xFixture, setXFixture] = useState({});

  const handleResize = () => {
    setVisWidth(visContainerRef.current.clientWidth - 24);
  };

  const leftAxisKeys = useMemo(
    () => hasData ? Object.keys(data[0]).filter(k => leftAxisSelectedItems.includes(k)) : [],
    [data, leftAxisSelectedItems]
  );
  const rightAxisKeys = useMemo(
    () => hasData ? Object.keys(data[0]).filter(k => rightAxisSelectedItems.includes(k)) : [],
    [data, rightAxisSelectedItems]
  );
  const breakdownKeys = useMemo(
    () => getDistinctPropertyValues(data),
    [data]
  );
  const leftAxisBreakdownKeys = useMemo(
    () => {
      if (leftAxisSelectedItems.length === 0) {
        return [];
      }

      const [singleKey] = leftAxisSelectedItems;
      return breakdownKeys.length > 0 ? breakdownKeys.map(bKey => `${bKey}___${singleKey}`) : [singleKey]
    },
    [leftAxisSelectedItems, breakdownKeys]
  );
  const rightAxisBreakdownKeys = useMemo(
    () => {
      if (rightAxisSelectedItems.length === 0) {
        return [];
      }

      const [singleKey] = rightAxisSelectedItems;
      return breakdownKeys.length > 0 ? breakdownKeys.map(bKey => `${bKey}___${singleKey}`) : [singleKey]
    },
    [rightAxisSelectedItems, breakdownKeys]
  );

  const leftAxisBarRefs = Array.from({ length: data.length }, () => createRef());
  const leftAxisLineRefs = data.length > 0 && Array.from({ length: data.length * leftAxisSelectedItems.length }, () => createRef());
  const leftAxisAreaRefs = data.length > 0 && Array.from({ length: data.length * leftAxisSelectedItems.length }, () => createRef());
  const leftAxisStackedBarRefs = data.length > 0 && Array.from({ length: data.length * leftAxisBreakdownKeys.length }, () => createRef());
  const rightAxisBarRefs = Array.from({ length: data.length }, () => createRef());
  const rightAxisLineRefs = data.length > 0 && Array.from({ length: data.length * rightAxisSelectedItems.length }, () => createRef());
  const rightAxisAreaRefs = data.length > 0 && Array.from({ length: data.length * rightAxisSelectedItems.length }, () => createRef());
  const rightAxisStackedBarRefs = data.length > 0 && Array.from({ length: data.length * leftAxisBreakdownKeys.length }, () => createRef());

  // bounds
  const xMax = visWidth - hMargin;
  const yMax = height - vMargin;

  // scales
  const x0Scale = useMemo(
    () => scaleBand({
      domain: hasData ? data.map(x0ScaleValueAccessor) : [],
      padding: 0.2,
    }),
    [data]
  );
  const leftAxisX1Scale = useMemo(
    () => scaleBand({
      domain: leftAxisKeys,
      padding: 0.1,
    }),
    [leftAxisKeys]
  );
  const rightAxisX1Scale = useMemo(
    () => scaleBand({
      domain: rightAxisKeys,
      padding: 0.1,
    }),
    [rightAxisKeys]
  );
  const commonYScale = useMemo(
    () => scaleLinear({
      domain: [
        0,
        hasData ?
          Math.max(
            ...data.map(
              d => Math.max(
                ...(leftAxisRendition !== 'stacked_bar' ?
                  leftAxisKeys.map(key => Number(d[key])) :
                  [leftAxisBreakdownKeys.reduce((acc, key) => acc + Number(d[key] ?? 0), 0)]),
                ...(rightAxisRendition !== 'stacked_bar' ?
                  rightAxisKeys.map(key => Number(d[key])) :
                  [rightAxisBreakdownKeys.reduce((acc, key) => acc + Number(d[key] ?? 0), 0)]),
              )
            )
          )
          : 0
      ],
    }),
    [data, leftAxisKeys, rightAxisKeys, leftAxisRendition, rightAxisRendition, leftAxisBreakdownKeys, rightAxisBreakdownKeys]
  );

  const leftAxisYScale = useMemo(
    () => scaleLinear({
      domain: [
        0,
        hasData ?
          Math.max(
            ...data.map(
              d => Math.max(
                ...(leftAxisRendition !== 'stacked_bar' ?
                  leftAxisKeys.map(key => Number(d[key])) :
                  [leftAxisBreakdownKeys.reduce((acc, key) => acc + Number(d[key] ?? 0), 0)]),
              )
            )
          )
          : 0
      ],
    }),
    [data, leftAxisKeys, leftAxisRendition, leftAxisBreakdownKeys]
  );

  const rightAxisYScale = useMemo(
    () => scaleLinear({
      domain: [
        0,
        hasData ?
          Math.max(
            ...data.map(
              d => Math.max(
                ...(rightAxisRendition !== 'stacked_bar' ?
                  rightAxisKeys.map(key => Number(d[key])) :
                  [rightAxisBreakdownKeys.reduce((acc, key) => acc + Number(d[key]), 0)]),
              )
            )
          )
          : 0
      ],
    }),
    [data, rightAxisKeys, rightAxisRendition, rightAxisBreakdownKeys]
  );

  const leftAxisColorScale = scaleOrdinal({
    domain: leftAxisKeys,
    range: Object.values(theme.palette.vis.blendedVis.leftAxis),
  });
  const leftAxisTransparentColorScale = scaleOrdinal({
    domain: leftAxisKeys,
    range: Object.values(theme.palette.vis.blendedVis.leftAxisTransparent),
  });

  const rightAxisColorScale = scaleOrdinal({
    domain: rightAxisKeys,
    range: Object.values(theme.palette.vis.blendedVis.rightAxis),
  });
  const rightAxisTransparentColorScale = scaleOrdinal({
    domain: rightAxisKeys,
    range: Object.values(theme.palette.vis.blendedVis.rightAxisTransparent),
  });

  // update scale output dimensions
  x0Scale.rangeRound([0, (axisUseCommonYScale ? xMax : xMax - hMargin)]);
  leftAxisX1Scale.rangeRound([0, x0Scale.bandwidth()]);
  rightAxisX1Scale.rangeRound([0, x0Scale.bandwidth()]);
  commonYScale.range([yMax, 0]);
  leftAxisYScale.range([yMax, 0]);
  rightAxisYScale.range([yMax, 0]);

  // calculate half a horizontal step width to offset the curves
  const curveOffset = (hasData ? x0Scale(x0ScaleValueAccessor(data[1])) - x0Scale(x0ScaleValueAccessor(data[0])) : 0) * 0.3846;

  const visRelativeX = x => x - visContainerRef.current.getBoundingClientRect().left - 16 - hMargin;

  const findClosestXPointTo = x => {
    if (data.length < 1) {
      return 0;
    }

    const relativeX = visRelativeX(x);

    let closestX = x0Scale(x0ScaleValueAccessor(data[0])) + curveOffset;
    let closestXKey = x0ScaleValueAccessor(data[0]);

    if (data.length === 1) {
      return closestX;
    }

    for (let index = 1; index < data.length; index++) {
      const currentX = x0Scale(x0ScaleValueAccessor(data[index])) + curveOffset;
      const currentXKey = x0ScaleValueAccessor(data[index]);

      if (Math.abs(relativeX - currentX) < Math.abs(relativeX - closestX)) {
        closestX = currentX;
        closestXKey = currentXKey;
      }
    }

    setXFixture({
      key: closestXKey,
      value: closestX
    });

    return closestX;
  };

  const findVisItemData = (refs, x) => {
    let foundData = [];

    refs.forEach((ref, index) => {
      const element = ref.current;

      if (element) {
        const rect = element.getBoundingClientRect();

        if (
          x >= (visRelativeX(rect.left) - 2) &&
          x <= (visRelativeX(rect.right) + 2)
        ) {
          foundData = foundData.concat(JSON.parse(element.getAttribute('data-value') ?? ''));
        } else if ( x > visRelativeX(rect.right) ) {
          return;
        }
      }
    });

    return foundData;
  };

  const calculateTooltipData = (cursorX, cursorY) => {
    let leftAxisData;
    let rightAxisData;

    const closestXPoint = findClosestXPointTo(cursorX);

    if (leftAxisRendition === 'bar') {
      leftAxisData = findVisItemData(leftAxisBarRefs, closestXPoint);
    }

    if (leftAxisRendition === 'line') {
      leftAxisData = findVisItemData(leftAxisLineRefs, closestXPoint);
    }

    if (leftAxisRendition === 'area') {
      leftAxisData = findVisItemData(leftAxisAreaRefs, closestXPoint);
    }

    if (leftAxisRendition === 'stacked_bar') {
      leftAxisData = findVisItemData(leftAxisStackedBarRefs, closestXPoint);
    }

    if (rightAxisRendition === 'bar') {
      rightAxisData = findVisItemData(rightAxisBarRefs, closestXPoint);
    }

    if (rightAxisRendition === 'line') {
      rightAxisData = findVisItemData(rightAxisLineRefs, closestXPoint);
    }

    if (rightAxisRendition === 'area') {
      rightAxisData = findVisItemData(rightAxisAreaRefs, closestXPoint);
    }

    if (rightAxisRendition === 'stacked_bar') {
      rightAxisData = findVisItemData(rightAxisStackedBarRefs, closestXPoint);
    }

    return {
      closestXPoint: closestXPoint,
      leftAxisData: leftAxisData,
      rightAxisData: rightAxisData
    };
  };

  function getDistinctPropertyValues(objects) {
    if (!objects || Object.values(objects).length === 0) {
      return [];
    }
    const prefixes = new Set();

    objects.forEach(obj => {
      Object.keys(obj).forEach(key => {
        if (key.includes('__')) {
          const prefix = key.split('__')[0];
          prefixes.add(prefix);
        }
      });
    });
    return Array.from(prefixes);
  }

  const AreaRendition = ({ keys, yScale, colorScale, refsArray }) => (
    <Group left={curveOffset}>
      {keys.map((key, kIndex) => (
        <Group key={`group_area__${key}`}>
          {data.map((d, dIndex ) => (
            <circle
              key={`area_point__${key}__${x0ScaleValueAccessor(d)}`}
              ref={refsArray[dIndex * keys.length + kIndex]}
              data-value={JSON.stringify([{ key: key, value: d[key], color: colorScale(key) }])}
              r={3}
              cx={x0Scale(x0ScaleValueAccessor(d))}
              cy={yScale(d[key])}
              stroke="rgba(33,33,33,0.5)"
              fill="transparent"
            />
          ))}

          <AreaClosed
            data={data}
            x={(d) => x0Scale(x0ScaleValueAccessor(d))}
            y={(d) => yScale(d[key])}
            yScale={yScale}
            strokeWidth={1}
            stroke={colorScale(key)}
            fill={colorScale(key)}
            curve={curveMonotoneX}
          />
        </Group>
      ))}
    </Group>
  );

  const LineRendition = ({ keys, yScale, colorScale, refsArray }) => (
    <Group left={curveOffset}>
      {keys.map((key, kIndex) => (
        <Group key={`group_line__${key}`}>
          {data.map((d, dIndex) => (
            <circle
              key={`line_point__${key}__${x0ScaleValueAccessor(d)}`}
              ref={refsArray[dIndex * keys.length + kIndex]}
              data-value={JSON.stringify([{ key: key, value: d[key], color: colorScale(key) }])}
              r={3}
              cx={x0Scale(x0ScaleValueAccessor(d))}
              cy={yScale(d[key])}
              stroke={colorScale(key)}
              fill="transparent"
            />
          ))}

          <LinePath
            key={`line__${key}`}
            curve={curveNatural}
            data={data}
            x={(d) => x0Scale(x0ScaleValueAccessor(d))}
            y={(d) => yScale(d[key])}
            stroke={colorScale(key)}
            strokeWidth={2}
            strokeOpacity={1}
            shapeRendering="geometricPrecision"
          />
        </Group>
      ))}
    </Group>
  );

  const BarGroupRendition = ({ keys, x1Scale, yScale, colorScale, refsArray }) => (
    <BarGroupVis
      data={data}
      keys={keys}
      height={yMax}
      x0={x0ScaleValueAccessor}
      x0Scale={x0Scale}
      x1Scale={x1Scale}
      yScale={yScale}
      color={colorScale}
    >
      {(barGroups) =>
        barGroups.map((barGroup, barGroupIndex) => (
          <g
            key={`bar-group-g-${barGroup.index}-${barGroup.x0}`}
            data-value={JSON.stringify(barGroup.bars.map(bar => ({ key: bar.key, value: bar.value, color: bar.color })))}
            ref={refsArray[barGroupIndex]}
          >
            <Group
              key={`bar-group-${barGroup.index}-${barGroup.x0}`}
              left={barGroup.x0}
            >
              {barGroup.bars.map((bar) => (
                <rect
                  key={`bar-group-bar-${barGroup.index}-${bar.index}-${bar.value}-${bar.key}`}
                  x={bar.x}
                  y={bar.y}
                  width={bar.width}
                  height={bar.height}
                  fill={bar.color}
                  rx={4}
                />
              ))}
            </Group>
          </g>
        ))
      }
    </BarGroupVis>
  );

  const StackedBarRendition = ({ keys, breakdownKeys, yScale, colorScale, refsArray }) => {
    const singleKey = [keys];
    let finalKeys = breakdownKeys.length > 0 ? breakdownKeys.map(bKey => `${bKey}___${singleKey}`) : [singleKey];

    return keys.length === 0 ? null : (
      <BarStack
        data={data}
        keys={finalKeys}
        x={x0ScaleValueAccessor}
        xScale={x0Scale}
        yScale={yScale}
        color={colorScale}
      >
        {barStacks =>
          barStacks.map((barStack, barStackIndex) => (
            <g
              key={`bar-stack-${barStack.index}`}
            >
              {barStack.bars.map((bar, barIndex) => (
                <rect
                  key={`bar-stack-${barStack.index}-bar-${barIndex}`}
                  data-value={JSON.stringify({key: bar.key, value: bar.bar[1] - bar.bar[0], color: bar.color})}
                  ref={refsArray[barIndex * finalKeys.length + barStackIndex]}
                  x={bar.x}
                  y={bar.y}
                  height={bar.height}
                  width={bar.width}
                  fill={bar.color}
                  rx={4}
                />
              ))}
            </g>
          ))
        }
      </BarStack>
    );
  };

  return visWidth < 10 || !hasData ? null : (
    <Paper sx={{
      p: '16px',
      width: "100%",
      display: "flex",
      flexDirection: "column",
      alignItems: "center",
      position: 'relative',
      '&:hover .download-icon': {visibility: 'visible'}
    }}
           ref={visContainerRef}>
      <svg width={visWidth} height={height}
           onMouseLeave={() => {
             hideTooltip();
           }}
           onMouseMove={(event) => {
             showTooltip({
               tooltipData: calculateTooltipData(event.clientX, event.clientY),
               tooltipTop: event.clientY,
               tooltipLeft: event.clientX,
             });
           }}>
        <Group
          top={vMargin / 6}
          left={hMargin}>

          {axisUseCommonYScale && (
            <GridRows
              scale={commonYScale}
              width={xMax}
              height={yMax}
              stroke={theme.palette.text.tertiary}
            />
          )}

          {tooltipOpen && (
            <Line
              from={{ x: xFixture.value, y: 0 }}
              to={{ x: xFixture.value, y: yMax }}
              stroke={theme.palette.text.primary}
              strokeWidth={2}
              pointerEvents="none"
              strokeDasharray="5,2"
            />
          )}

          {leftAxisRendition === 'stacked_bar' && (
            <StackedBarRendition
              keys={leftAxisSelectedItems}
              breakdownKeys={breakdownKeys}
              yScale={axisUseCommonYScale ? commonYScale : leftAxisYScale}
              colorScale={leftAxisColorScale}
              refsArray={leftAxisStackedBarRefs}
            />
          )}

          {rightAxisRendition === 'stacked_bar' && (
            <StackedBarRendition
              keys={rightAxisSelectedItems}
              breakdownKeys={breakdownKeys}
              yScale={axisUseCommonYScale ? commonYScale : rightAxisYScale}
              colorScale={rightAxisColorScale}
              refsArray={rightAxisStackedBarRefs}
            />
          )}

          {leftAxisRendition === 'bar' && (
            <BarGroupRendition
              keys={leftAxisKeys}
              x1Scale={leftAxisX1Scale}
              yScale={axisUseCommonYScale ? commonYScale : leftAxisYScale}
              colorScale={leftAxisColorScale}
              refsArray={leftAxisBarRefs}
            />
          )}

          {rightAxisRendition === 'bar' && (
            <BarGroupRendition
              keys={rightAxisKeys}
              x1Scale={rightAxisX1Scale}
              yScale={axisUseCommonYScale ? commonYScale : rightAxisYScale}
              colorScale={rightAxisColorScale}
              refsArray={rightAxisBarRefs}
            />
          )}

          {leftAxisRendition === 'line' && (
            <LineRendition
              keys={leftAxisKeys}
              yScale={axisUseCommonYScale ? commonYScale : leftAxisYScale}
              colorScale={leftAxisColorScale}
              refsArray={leftAxisLineRefs}
            />
          )}

          {rightAxisRendition === 'line' && (
            <LineRendition
              keys={rightAxisKeys}
              yScale={axisUseCommonYScale ? commonYScale : rightAxisYScale}
              colorScale={rightAxisColorScale}
              refsArray={rightAxisLineRefs}
            />
          )}

          {leftAxisRendition === 'area' && (
            <AreaRendition
              keys={leftAxisKeys}
              yScale={axisUseCommonYScale ? commonYScale : leftAxisYScale}
              colorScale={leftAxisTransparentColorScale}
              refsArray={leftAxisAreaRefs}
            />
          )}

          {rightAxisRendition === 'area' && (
            <AreaRendition
              keys={rightAxisKeys}
              yScale={axisUseCommonYScale ? commonYScale : rightAxisYScale}
              colorScale={rightAxisTransparentColorScale}
              refsArray={rightAxisAreaRefs}
            />
          )}

          {tooltipOpen && tooltipData && (
            <TooltipInPortal top={tooltipTop} left={tooltipLeft} style={tooltipStyles}>
              <Box sx={{ width: '100%', display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                <Typography variant="h5">{ xFixture.key }</Typography>
              </Box>

              <Divider sx={{ my: 2 }} />

              <Grid container sx={{ p: 2, minWidth: '400px' }}>
                <Grid item xs={6} sx={{ px: 6 }}>
                  <Box sx={{ display: 'flex', alignItems: 'start', color: 'tertiary.main', mb: 4 }}>
                    {leftAxisRendition === 'bar' && (<BarChart sx={{ mr: 2 }} />)}
                    {leftAxisRendition === 'line' && (<Timeline sx={{ mr: 2 }} />)}
                    {leftAxisRendition === 'area' && (<StackedLineChart sx={{ mr: 2 }} />)}
                    {leftAxisRendition === 'stacked_bar' && (<StackedBarChart sx={{ mr: 2 }} />)}

                    <Box>
                      <Typography variant="h6">Left axis</Typography>
                      <Typography variant="overline" sx={{ lineHeight: '10px', color: 'text.secondary' }}>{capitalizeSnakeCase(leftAxisRendition)}</Typography>
                    </Box>
                  </Box>

                  {tooltipData.leftAxisData && tooltipData.leftAxisData.map(datum => (
                    <Box key={`tooltip__box__left_axis__${datum.key}`} sx={{ width: '100%', display: 'flex', justifyContent: 'space-between', alignItems: 'center', mb: 1 }}>
                      <Box sx={{ display: 'flex', alignItems: 'center' }}>
                        <Circle sx={{ color: datum.color, mr: 1 }}/>

                        <Typography variant="body2">{capitalizeSnakeCase(datum.key)}</Typography>
                      </Box>

                      <Typography variant="body2"><b>{valueFormatters && valueFormatters.hasOwnProperty(datum.key) ? valueFormatters[datum.key](datum.value) : datum.value}</b></Typography>
                    </Box>
                  ))}
                </Grid>

                <Grid item xs={6} sx={{ px: 6 }}>
                  <Box sx={{ display: 'flex', alignItems: 'start', color: 'tertiary.main', mb: 4 }}>
                    {rightAxisRendition === 'bar' && (<BarChart sx={{ mr: 2 }} />)}
                    {rightAxisRendition === 'line' && (<Timeline sx={{ mr: 2 }} />)}
                    {rightAxisRendition === 'area' && (<StackedLineChart sx={{ mr: 2 }} />)}
                    {rightAxisRendition === 'stacked_bar' && (<StackedBarChart sx={{ mr: 2 }} />)}

                    <Box>
                      <Typography variant="h6">Right axis</Typography>
                      <Typography variant="overline" sx={{ lineHeight: '10px', color: 'text.secondary' }}>{capitalizeSnakeCase(rightAxisRendition)}</Typography>
                    </Box>
                  </Box>

                  {tooltipData.rightAxisData && tooltipData.rightAxisData.map(datum => (
                    <Box key={`tooltip__box__right_axis__${datum.key}`} sx={{ width: '100%', display: 'flex', justifyContent: 'space-between' }}>
                      <Box sx={{ display: 'flex', alignItems: 'center', mr: 4 }}>
                        <Circle sx={{ color: datum.color, mr: 1 }}/>

                        <Typography variant="body2" sx={{ whiteSpace: 'no-wrap' }}>{capitalizeSnakeCase(datum.key)}</Typography>
                      </Box>

                      <Typography variant="body2"><b>{valueFormatters && valueFormatters.hasOwnProperty(datum.key) ? valueFormatters[datum.key](datum.value) : datum.value}</b></Typography>
                    </Box>
                  ))}
                </Grid>
              </Grid>
            </TooltipInPortal>
          )}
        </Group>

        <AxisLeft
          scale={axisUseCommonYScale ? commonYScale : leftAxisYScale}
          left={hMargin}
          top={vMargin / 6}
          labelProps={{ fill: theme.palette.text.primary }}
          tickLabelProps={{ fill: theme.palette.text.primary }}
          hideTicks={true}
          hideAxisLine={true}
          label={yAxisLabel}
        />

        {!axisUseCommonYScale && (
          <AxisRight
            scale={rightAxisYScale}
            left={xMax}
            top={vMargin / 6}
            labelProps={{ fill: theme.palette.text.primary }}
            tickLabelProps={{ fill: theme.palette.text.primary }}
            hideTicks={true}
            hideAxisLine={true} />
        )}

        <AxisBottom
          left={hMargin}
          top={yMax + vMargin / 2}
          labelProps={{ fill: theme.palette.text.primary }}
          tickLabelProps={{ fill: theme.palette.text.primary }}
          scale={x0Scale}
          hideTicks={true}
          hideAxisLine={true}
          label={xAxisLabel}
        />
      </svg>

      <ResizeListener onResize={handleResize}/>
    </Paper>
  );
};

export default BlendedVis;