import React, { useEffect, useRef, useState } from "react";
import PropTypes from "prop-types";
import styles from "./styles.css";
import colors from "assets/theme/base/colors";
import XBox from "components/XBox";

import * as d3 from "d3";

import { AxisLeft } from "./AxisLeft";
import { AxisBottom } from "./AxisBottom";
import { useXplainableController } from "context";

import { select, scaleLinear } from "d3";
import rgba from "assets/theme/functions/rgba";
import { usePrevious } from "shared/models/ModelProfile/components/shared";

// Space (px) reserved around the plotting area for axes and the legend.
const MARGIN = {
  top: 60,
  right: 60,
  bottom: 30,
  left: 60,
};

// Half-length (px) of each arm of the "X" mean marker; the arms span -MARK_SIZE..MARK_SIZE.
const MARK_SIZE = 6;

const Scatterplot = ({ width: defaultWidth, height, data, isViolinVisible }) => {
  const [controller] = useXplainableController();
  const { darkMode } = controller;

  const [marksVisible, setMarksVisible] = useState(true); // State to control visibility of marks and line
  const [width, setWidth] = useState(defaultWidth);
  const [brushSumstat, setBrushSumstat] = useState(null);

  const svgContainer = useRef(null);
  const svgRef = useRef();

  const boundsWidth = width - MARGIN.right - MARGIN.left;
  const boundsHeight = height - MARGIN.top - MARGIN.bottom;

  var svg = d3.select("#my_dataviz");

  const uniqueGroups = [...new Set(data.map((obj) => obj.added_date))];

  const [selection, setSelection] = useState([uniqueGroups.length || 0, 0]);
  const previousSelection = usePrevious(selection);

  const sortedGroups = () => {
    if (Math.round(selection[1]) > Math.round(selection[0])) {
      return uniqueGroups.toSorted((a, b) => new Date(a) - new Date(b));
    }

    return uniqueGroups.toSorted((a, b) => new Date(b) - new Date(a));
  };

  const slicedGroups = sortedGroups().slice(Math.round(selection[1]), Math.round(selection[0]));
  const filteredObject = data.filter((obj) => slicedGroups.includes(obj.added_date));

  useEffect(() => {
    const container = svgContainer.current;

    const handleResize = (entries) => {
      for (let entry of entries) {
        const newWidth = entry.contentRect.width;
        setWidth(newWidth);
      }
    };

    const resizeObserver = new ResizeObserver(handleResize);

    // Observe the container element
    if (container) {
      resizeObserver.observe(container);
    }

    return () => {
      // Disconnect the observer when the component unmounts
      resizeObserver.disconnect();
    };
  }, [svgContainer]);

  // Transforming data
  const processedData = filteredObject.map((d) => ({
    x: new Date(d.added_date),
    y: d.score,
    group: d.added_date,
  }));

  // Finding the extent of the date for the xScale and score for yScale
  const xExtent = d3.extent(processedData, (d) => d.x);
  const yExtent = d3.extent(processedData, (d) => d.y);

  // Scales
  const xScale = d3.scaleTime().domain(xExtent).range([0, boundsWidth]);
  const yScale = d3.scaleLinear().domain(yExtent).range([boundsHeight, 0]);

  var x = d3.scaleBand().range([0, boundsWidth]).domain(xExtent);

  var histogram = d3
    .histogram()
    .domain(yScale.domain())
    .thresholds(yScale.ticks(20)) // Important: how many bins approx are going to be made? It is the 'resolution' of the violin plot
    .value((d) => d);

  // Color scale
  const colorScale = d3
    .scaleLinear()
    .domain(yExtent)
    .range([colors.xppink.main, colors.xpblue.main]);

  // Initialize objects to store sums and counts for dates and scores
  let groupDateSums = {},
    groupScoreSums = {},
    groupCounts = {};

  processedData.forEach((d) => {
    // Check if the group already exists in the sum/count objects
    if (!groupDateSums.hasOwnProperty(d.group)) {
      groupDateSums[d.group] = 0;
      groupScoreSums[d.group] = 0;
      groupCounts[d.group] = 0;
    }

    // Add the current item's date (in milliseconds) and score to the group sums
    groupDateSums[d.group] += d.x.getTime();
    groupScoreSums[d.group] += d.y;

    // Increment the count for the current group
    groupCounts[d.group]++;
  });

  // Calculate the mean date and score for each group
  const groupMeans = Object.keys(groupCounts).map((group) => ({
    group,
    meanX: new Date(groupDateSums[group] / groupCounts[group]), // Mean date
    meanY: groupScoreSums[group] / groupCounts[group], // Mean score
  }));

  // Sort the group means by the mean date in ascending order
  groupMeans.sort((a, b) => a.meanX - b.meanX);

  var sumstat = d3
    .nest()
    .key(function (d) {
      return d.added_date;
    })
    .rollup(function (d) {
      // For each key..
      const input = d.map(function (g) {
        return g.score;
      }); // Keep the variable called Sepal_Length
      const bins = histogram(input); // And compute the binning on it.
      return bins;
    })
    .entries(data);

  var maxNum = 0;

  for (let i in sumstat) {
    const allBins = sumstat[i].value;
    const lengths = allBins.map(function (a) {
      return a.length;
    });
    const longuest = d3.max(lengths);

    if (longuest > maxNum) {
      maxNum = longuest;
    }
  }

  if (!brushSumstat) {
    setBrushSumstat(sumstat);
  }

  var xNum = d3.scaleLinear().range([0, 80]).domain([-maxNum, maxNum]);
  svg.selectAll(".violin").remove();

  const svgDefs = svg.append("defs");

  const gradient = svgDefs
    .append("linearGradient")
    .attr("id", "violinGradient")
    .attr("gradientUnits", "userSpaceOnUse") // This makes the gradient scale to match your data scale
    .attr("x1", 0)
    .attr("y1", yScale(yExtent[0])) // Top of the yScale
    .attr("x2", 0)
    .attr("y2", yScale(yExtent[1])); // Bottom of the yScale

  // Assuming colors.xppink.main and colors.xpblue.main are your start and end colors
  gradient.append("stop").attr("offset", "0%").attr("stop-color", colors.xppink.main);

  gradient.append("stop").attr("offset", "100%").attr("stop-color", colors.xpblue.main);

  if (isViolinVisible) {
    svg
      .selectAll("myViolin")
      .data(sumstat)
      .enter()
      .append("g")
      .attr("class", "violin")
      .attr("transform", function (d) {
        return `translate(${xScale(new Date(d.key)) - xNum(0) + MARGIN.left + 8}, 60)`;
      })
      .append("path")
      .datum(function (d) {
        return d.value;
      })
      .style("stroke", "none")
      .style("fill", "url(#violinGradient)")
      .attr(
        "d",
        d3
          .area()
          .x0(xNum(0))
          .x1(function (d) {
            return xNum(d.length);
          })
          .y(function (d) {
            return yScale(d.x0);
          })
          .curve(d3.curveCatmullRom) // Smooth line for the violin appearance
      );
  }

  // Build the shapes (circles)
  const allShapes = processedData.map((d, i) => (
    <circle
      key={i}
      r={5}
      cx={xScale(d.x)}
      cy={yScale(d.y)}
      className={styles.scatterplotCircle}
      stroke={colorScale(d.y)}
      fill={colorScale(d.y)}
      fillOpacity={0.5}
    />
  ));

  // Toggle visibility function
  const toggleMarksVisibility = () => {
    setMarksVisible(!marksVisible);
  };

  useEffect(() => {
    if (!brushSumstat) return;

    const svg = select(svgRef.current);

    const xScale = scaleLinear()
      .domain([0, brushSumstat.length])
      .range([width - 50, 10]);

    svg
      .selectAll(".brush-bar")
      .data(brushSumstat)
      .join("rect")
      .attr("class", "brush-bar")
      .attr("fill", (d, index) => {
        const isInRange =
          selection && index >= Math.round(selection[1]) && index < Math.round(selection[0]);

        return !isInRange ? colors.light.main : "url(#violinGradient)";
      })

      .attr("y", 5)
      .attr("x", (d, index) => xScale(index) - 100 || 0)
      .attr("height", 20) // subtract twice the padding from the height
      .attr("width", (d) => 20 || 0) //||xScale(d.value) || 0)
      .attr("rx", 5) // Add this line for x-axis corner radius
      .attr("ry", 5); // Add this line for y-axis corner radius

    const brush = d3
      .brushX()
      .extent([
        [0, 0],
        [width - 50, 30],
      ]) // Define the extent for the brush along the x-axis
      .on("start brush end", function () {
        const event = d3.event;

        // Access selection range
        const selectionRange = event.selection.map(xScale.invert);

        // Handle logic for selection
        if (selectionRange) {
          // Do something with the selection range
          setSelection(selectionRange);
        }
      });

    svg
      .select(".brush")
      .call(brush)
      .selectAll(".selection")
      .style("stroke", rgba(colors.secondary.main, 0.8)) // Change the stroke color of the selection
      .style("fill", colors.light.main) // Change the fill color of the selection
      .style("fill-opacity", 0.2); // Change the fill opacity of the selection

    svg
      .select(".brush")
      .call(brush)
      .selectAll(".handle")
      .style("fill", rgba(colors.secondary.main, 0.8)) // Change the fill color of the handles
      .style("stroke", rgba(colors.secondary.main, 0.8)) // Change the stroke color of the handles
      .style("stroke-width", "1.5px"); // Change the stroke width of the handles

    if (previousSelection === selection) {
      if (selection) {
        svg.select(".brush").call(brush).call(brush.move, selection.map(xScale));
      }
    }
  }, [data, brushSumstat, previousSelection, width, selection]);

  return (
    <>
      <div ref={svgContainer}>
        <svg width={width} height={height} id="my_dataviz">
          <g transform={`translate(${MARGIN.left}, ${MARGIN.top})`}>
            <AxisLeft yScale={yScale} pixelsPerTick={40} width={boundsWidth} />
            <g transform={`translate(0, ${boundsHeight})`}>
              <AxisBottom
                xScale={xScale}
                pixelsPerTick={75} // Adjust based on your layout
                tickFormat={d3.timeFormat("%Y-%m-%d")} // Format dates as "YYYY-MM-DD"
                height={boundsHeight}
              />
            </g>
            {allShapes}
            {marksVisible &&
              groupMeans.map((group, i) => (
                <g key={i} transform={`translate(${xScale(group.meanX)}, ${yScale(group.meanY)})`}>
                  <line
                    x1={-MARK_SIZE}
                    y1={-MARK_SIZE}
                    x2={MARK_SIZE}
                    y2={MARK_SIZE}
                    stroke={darkMode ? "white" : "black"}
                  />
                  <line
                    x1={-MARK_SIZE}
                    y1={MARK_SIZE}
                    x2={MARK_SIZE}
                    y2={-MARK_SIZE}
                    stroke={darkMode ? "white" : "black"}
                  />
                </g>
              ))}
            {marksVisible && (
              <path
                d={d3.line()(groupMeans.map((group) => [xScale(group.meanX), yScale(group.meanY)]))}
                fill="none"
                stroke={darkMode ? "white" : "black"}
              />
            )}
          </g>
          {/* Adjusted Legend Positioning */}
          <text
            x={width - MARGIN.right} // Adjust for the right margin
            y={MARGIN.top - 30} // Position above the top margin
            textAnchor="end"
            style={{
              cursor: "pointer",
              fill: darkMode ? "white" : "black",
              userSelect: "none",
              fontSize: "14px",
            }} // Prevent text selection on click
            onClick={toggleMarksVisibility}
          >
            &#x2715; Mean Value
          </text>
        </svg>
      </div>
      <XBox width="40px" sx={{ marginTop: "10px", position: "relative", pr: 5 }}>
        <svg
          ref={svgRef}
          height="50px"
          width={width - 50}
          style={{ transform: "translate(30px, 0)" }}
        >
          <g className="brush" />
          <g className="x-axis" />
          <g className="y-axis" />
        </svg>
      </XBox>
    </>
  );
};

// Runtime prop validation. `added_date` and `score` are the row fields the component reads;
// `added_date` is parsed with `new Date(...)` and is also used as the grouping key.
Scatterplot.propTypes = {
  width: PropTypes.number.isRequired,
  height: PropTypes.number.isRequired,
  isViolinVisible: PropTypes.bool,
  data: PropTypes.arrayOf(
    PropTypes.shape({
      scenario_id: PropTypes.string.isRequired,
      created: PropTypes.string.isRequired,
      added_date: PropTypes.oneOfType([
        PropTypes.string,
        PropTypes.number,
        PropTypes.instanceOf(Date),
      ]).isRequired,
      score: PropTypes.number.isRequired,
    })
  ).isRequired,
};

export default Scatterplot;
