import React, { useEffect, useState, useRef, useContext } from "react";
import PropTypes from "prop-types";
import * as d3 from "d3";
import colors from "assets/theme/base/colors";
import XBox from "components/XBox";
import { useModel } from "hooks";
import { useXplainableController } from "context";

const PerformanceChart = ({ data, type }) => {
  const [controller] = useXplainableController();
  const { darkMode } = controller;
  const { collapsedWidth } = useModel();
  const svgRef = useRef();
  const svgContainer = useRef(null); // The PARENT of the SVG
  const [width, setWidth] = useState(550);
  const [height, setHeight] = useState(300);
  const margin = { top: 20, right: 20, bottom: 60, left: 40 };

  // This function calculates width and height of the container
  const getSvgContainerSize = () => {
    const newWidth = svgContainer.current.clientWidth;
    setWidth(newWidth);
  };

  useEffect(() => {
    // detect 'width' and 'height' on render
    getSvgContainerSize();
    // listen for resize changes, and detect dimensions again when they change
    window.addEventListener("resize", getSvgContainerSize);
    // cleanup event listener
    return () => window.removeEventListener("resize", getSvgContainerSize);
  }, [collapsedWidth]);

  const xKey = type === "ROC" ? "fpr" : "recall";
  const yKey = type === "ROC" ? "tpr" : "precision";
  const xAxisLabel = type === "ROC" ? "False Positive Rate" : "Recall";
  const yAxisLabel = type === "ROC" ? "True Positive Rate" : "Precision";

  useEffect(() => {
    if (!data[xKey] || data[xKey].length === 0 || !data[yKey] || data[yKey].length === 0) {
      return;
    }

    const svg = d3.select(svgRef.current);

    svg.attr("width", width).attr("height", height);

    // Scales
    const xScale = d3
      .scaleLinear()
      .domain([d3.min(data[xKey]), d3.max(data[xKey])])
      .range([margin.left, width - margin.right]);

    const yScale = d3
      .scaleLinear()
      .domain([d3.min(data[yKey]), d3.max(data[yKey])])
      .range([height - margin.bottom, margin.top]);

    // Axes
    const xAxis = d3.axisBottom(xScale).ticks(11);
    const yAxis = d3.axisLeft(yScale);

    const xAxisGroup = svg.select(".x-axis");
    const yAxisGroup = svg.select(".y-axis");

    if (!xAxisGroup.empty()) {
      xAxisGroup
        .attr("transform", `translate(0, ${height - margin.bottom})`)
        .call(xAxis)
        .selectAll(".tick text")
        .attr("class", "x-tick");
    }

    if (!yAxisGroup.empty()) {
      yAxisGroup.attr("transform", `translate(${margin.left}, 0)`).call(yAxis);
    }

    // Add axis labels
    let xAxisLabelGroup = svg.select(".x-label");
    let yAxisLabelGroup = svg.select(".y-label");

    if (xAxisLabelGroup.empty()) {
      xAxisLabelGroup = svg
        .append("text")
        .attr("class", "x-label")
        .attr("x", width - margin.right)
        .attr("y", height - margin.bottom)
        .attr("dy", "-0.5em")
        .attr("font-size", "12px")
        .attr("text-anchor", "end");
    }
    xAxisLabelGroup
      .attr("x", width - margin.right) // Update x position
      .attr("y", height - margin.bottom) // Update y position
      .text(xAxisLabel);

    if (yAxisLabelGroup.empty()) {
      yAxisLabelGroup = svg
        .append("text")
        .attr("class", "y-label")
        .attr("transform", "rotate(-90)")
        .attr("y", 30)
        .attr("dy", "2em")
        .attr("font-size", "12px")
        .attr("dx", "-1.5em") // shift 1em to the right
        .attr("text-anchor", "end");
    }
    yAxisLabelGroup.text(yAxisLabel);

    // Line generator
    const lineGenerator = d3
      .line()
      .x((d) => xScale(d.x))
      .y((d) => yScale(d.y));

    const dataPoints = data[xKey].map((d, i) => ({
      x: d,
      y: data[yKey][i],
    }));

    // Update line with a transition
    svg.select(".line-path").datum(dataPoints).transition().duration(500).attr("d", lineGenerator);

    // Remove existing dashed line, if any
    svg.select(".equality-line").remove();

    // Draw dashed x=y line for non-ROC type
    if (type === "ROC") {
      const equalityLineGenerator = d3
        .line()
        .x((d) => xScale(d.x))
        .y((d) => yScale(d.y));

      const equalityDataPoints = d3.range(0, 1, 0.01).map((d) => ({ x: d, y: d }));

      svg
        .append("path")
        .datum(equalityDataPoints)
        .attr("fill", "none")
        .attr("stroke", "grey")
        .attr("stroke-width", 1.5)
        .attr("stroke-dasharray", "4,2")
        .attr("d", equalityLineGenerator)
        .attr("class", "equality-line") // add unique class name
        .attr("opacity", 0) // start with opacity 0
        .transition()
        .delay(300) // add delay of 500 ms
        .attr("opacity", 1); // transition to full opacity
    }

    svg.selectAll("text").style("fill", darkMode ? colors.white.main : colors.black.main);
    svg.selectAll(".domain").style("stroke", darkMode ? colors.white.main : colors.black.main);
  }, [data, width, height, type, controller]);

  return (
    <XBox>
      <XBox ref={svgContainer}>
        <svg ref={svgRef}>
          <g className="x-axis" />
          <g className="y-axis" />
          <path className="line-path" fill="none" stroke={colors.xpblue.main} strokeWidth="2" />
        </svg>
      </XBox>
    </XBox>
  );
};

export default PerformanceChart;

// Shared validator for every curve-coordinate series: an array of numbers.
const numberSeries = PropTypes.arrayOf(PropTypes.number);

// Runtime prop validation for PerformanceChart.
// NOTE(review): `data` is marked isRequired but also has a default below. The
// required check can therefore only fire when `data` is passed explicitly as
// null. React also deprecates defaultProps on function components; consider
// moving this default into the component's parameter destructuring.
PerformanceChart.propTypes = {
  data: PropTypes.shape({
    recall: numberSeries,
    precision: numberSeries,
    fpr: numberSeries,
    tpr: numberSeries,
  }).isRequired,
  type: PropTypes.oneOf(["ROC", "PR"]).isRequired,
};

// Value used when `data` is omitted: every coordinate series is empty.
PerformanceChart.defaultProps = {
  data: {
    recall: [],
    precision: [],
    fpr: [],
    tpr: [],
  },
};
