import React, { useState, useRef, useEffect, useContext } from "react";
import { Slider } from "@mui/material";
import * as d3 from "d3";

import XBox from "components/XBox";
import XTypography from "components/XTypography";

import PropTypes from "prop-types";

import colors from "assets/theme/base/colors";
import { useModel } from "hooks";

import "assets/css/tooltip.css";
import { useXplainableController } from "context";

/**
 * Partitions `data` into consecutive chunks of about `binSize` elements and sums each chunk.
 *
 * Bin edges are `Math.round(k * binSize)`, so a fractional `binSize` (e.g. 100 / 30)
 * still covers every element exactly once. For an integer `binSize` every bin holds
 * exactly `binSize` elements except possibly the last, shorter one.
 *
 * @param {number[]} data - Values to bin; `null`/`undefined` is treated as empty.
 * @param {number} binSize - Elements per bin; must be >= 1 (`Infinity` yields one bin).
 * @returns {{x0: number, x1: number, sum: number}[]} One entry per bin: `x0` is the first
 *   index (inclusive), `x1` the end index (exclusive, clamped to `data.length`), and
 *   `sum` the total of `data[x0..x1)`.
 * @throws {RangeError} If `binSize` is not a number >= 1. A step <= 0 would never advance,
 *   and a step < 1 would produce empty bins.
 */
function groupAndSum(data, binSize) {
  // Written as a negated >= so that NaN/undefined are rejected as well.
  if (!(binSize >= 1)) {
    throw new RangeError(`binSize must be a number >= 1, got ${binSize}`);
  }
  if (data == null) {
    return [];
  }

  const bins = [];
  for (let start = 0; start < data.length; start += binSize) {
    const x0 = Math.round(start);
    // Accumulated float error can leave `start` just below the end; stop instead of
    // emitting an empty trailing bin.
    if (x0 >= data.length) {
      break;
    }
    // `start + binSize` is exactly the next iteration's `start`, so bins stay contiguous.
    const x1 = Math.min(Math.round(start + binSize), data.length);
    let sum = 0;
    for (let j = x0; j < x1; j += 1) {
      sum += data[j];
    }
    bins.push({ x0, x1, sum });
  }
  return bins;
}

function ProbabilityChart({ margin, setThreshold, dataset, collapsedWidth: propCollapsedWidth }) {
  const { collapsedWidth, probabilityData, calibrationMap } = useModel();

  const [controller] = useXplainableController();
  const { darkMode } = controller;

  const reportData = {
    train: {
      0: [],
      1: [],
    },
    validation: {
      0: [],
      1: [],
    },
  };

  const [binNumber, setBinNumber] = useState(10);
  const [width, setWidth] = useState(0);
  const [height, setHeight] = useState(250);
  const svgRef = useRef(null);
  const svgContainer = useRef(null); // The PARENT of the SVG
  let pos;

  const marks = [
    { value: 10, label: "10" },
    { value: 20, label: "20" },
    { value: 50, label: "50" },
    { value: 100, label: "100" },
  ];

  // The dataByClassAndType object is directly accessible
  const dataByClassAndType = probabilityData[dataset] || reportData;

  // Get all the classes (this will return an array of the class names)
  const classes = Object.keys(dataByClassAndType);

  //Get the labels for the plot
  const positiveLabel = classes[0];
  const negativeLabel = classes[1];

  // This function calculates width and height of the container
  const getSvgContainerSize = () => {
    const newWidth = svgContainer.current.clientWidth - margin.left - margin.right;
    setWidth(newWidth);
  };

  // Slider change handler.
  const handleBinNumberChange = (_, newValue) => {
    setBinNumber(newValue);
    FlushChart();
    DrawChart(probabilityData, margin);
  };

  useEffect(() => {
    // detect 'width' and 'height' on render
    getSvgContainerSize();
    // listen for resize changes, and detect dimensions again when they change
    window.addEventListener("resize", getSvgContainerSize);
    // cleanup event listener
    return () => window.removeEventListener("resize", getSvgContainerSize);
  }, [collapsedWidth]);

  useEffect(() => {
    FlushChart();
    DrawChart(probabilityData, margin);
  }, [probabilityData, dataset, height, width, controller]);

  const FlushChart = () => {
    d3.select(svgRef.current).selectAll("*").remove();
  };

  const DrawChart = (data, margin) => {
    //Create svg elements
    const svg = d3
      .select(svgRef.current)
      .attr("width", width + margin.left + margin.right)
      .attr("height", height + margin.top + margin.bottom + 30)
      .append("g")
      .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

    // Add X axis
    // X axis: scale and draw:
    var x = d3.scaleLinear().domain([0, 100]).range([0, width]);

    svg
      .append("g")
      .attr("transform", "translate(0," + height + ")")
      .attr("stroke-width", 0.7) // make horizontal tick thinner and lighter so that line paths can stand out
      .attr("opacity", 0.8)
      .call(d3.axisBottom(x))
      .selectAll("path, line, text")
      .style("stroke", darkMode ? colors.dark.main : colors.black.main);

    // //Add the Score text to the plot
    // svg
    //   .append("text")
    //   .attr("class", "x label")
    //   .attr("text-anchor", "start")
    //   .attr("x", width / 2.2)
    //   .attr("y", height + 50)
    //   .text("Score")
    //   .style("fill", darkMode ? colors.dark.main : colors.black.main);

    // set the parameters for the histogram
    const histogram = d3
      .histogram()
      .value(function (d) {
        return d;
      })
      .domain(x.domain())
      .thresholds(x.ticks(binNumber));

    console.log("The histogram data is", dataByClassAndType[classes[0]]);
    console.log("The bin is", binNumber);

    const binSize = 100 / binNumber; // Based on your range of 0-100
    const bins1 = groupAndSum(dataByClassAndType[classes[0]], binSize);
    const bins2 = groupAndSum(dataByClassAndType[classes[1]], binSize);

    //Return the max value of both bins
    const max1 = d3.max(bins1, function (d) {
      return d.sum;
    });
    const max2 = d3.max(bins2, function (d) {
      return d.sum;
    });
    const maxY = Math.max(max1, max2);

    //Add the linear y scale to the plot
    const y = d3.scaleLinear().range([height, 0]).domain([0, maxY]);

    const tooltip = d3.select("#tooltip");

    const onMouseEnter = (data, color) => {
      tooltip.select("#count").text(data.sum);

      const parentContainer = document.getElementById("probability-chart");
      const parentContainerHeight = parentContainer.getBoundingClientRect().height;

      const barElem = event.target.getBoundingClientRect();

      tooltip.classed("bottom-arrow", true);

      tooltip
        .style(
          "transform",
          `translate(${x(data.x0) + barElem.width / 2 + 10}px, ${
            parentContainerHeight - 120 - barElem.height
          }px`
        )
        .style("opacity", 0.9)
        .style("background", color)
        .style("color", colors.white.main); // Setting the text color

      // Set the alpha color for the ::before pseudo-element
      const styleElement = document.createElement("style");
      styleElement.innerHTML = `
      .tooltip::before {
          border-top: 11px solid ${color};
      }
    `;

      document.head.appendChild(styleElement);
    };

    const onMouseLeave = () => {
      tooltip.style("opacity", 0);
    };

    svg
      .append("g")
      .call(d3.axisLeft(y))
      .attr("stroke-width", 0.7) // make horizontal tick thinner and lighter so that line paths can stand out
      .attr("opacity", 0.8)
      .selectAll("path, line, text")
      .style("stroke", darkMode ? colors.dark.main : colors.black.main);

    const y1 = d3.scaleLinear().range([height, 0]);
    y1.domain([0, 1]);

    svg
      .append("g")
      .attr("transform", "translate(" + width + " ,0)")
      .attr("stroke-width", 0.7) // make horizontal tick thinner and lighter so that line paths can stand out
      .attr("opacity", 0.8)
      .call(d3.axisRight(y1))
      .selectAll("path, line, text")
      .style("stroke", darkMode ? colors.dark.main : colors.black.main);

    //Add x label
    svg
      .append("text")
      .attr("class", "x label")
      .attr("text-anchor", "start")
      .attr("x", -30)
      .attr("y", 15)
      .attr("font-size", "10px")
      .attr("transform", "rotate(-90)") // added this line
      .text("Count")
      .style("fill", darkMode ? colors.dark.main : colors.black.main);

    //Add x label
    svg
      .append("text")
      .attr("class", "x label")
      .attr("text-anchor", "start")
      .attr("x", width - 30)
      .attr("y", 0)
      .attr("font-size", "10px")
      .attr("transform", "rotate(-90," + width + ",10)") // added this line
      .text("Probability")
      .style("fill", darkMode ? colors.dark.main : colors.black.main);

    svg
      .selectAll("rect")
      .data(bins1)
      .enter()
      .append("rect")
      .attr("x", 1)
      .attr("transform", function (d) {
        return "translate(" + x(d.x0) + "," + y(d.sum) + ")";
      })
      .attr("width", function (d) {
        return x(d.x1) - x(d.x0) == 0 ? 0 : x(d.x1) - x(d.x0) - 1;
      })
      .attr("height", function (d) {
        return height - y(d.sum);
      })
      .style("fill", "#E14067")
      .style("opacity", 0.7)
      .on("mouseenter", (data) => onMouseEnter(data, "rgb(225, 64, 103)"))
      .on("mouseleave", onMouseLeave);

    svg
      .selectAll("rect2")
      .data(bins2)
      .enter()
      .append("rect")
      .attr("x", 1)
      .attr("transform", function (d) {
        return "translate(" + x(d.x0) + "," + y(d.sum) + ")";
      })
      .attr("width", function (d) {
        return x(d.x1) - x(d.x0) == 0 ? 0 : x(d.x1) - x(d.x0) - 1;
      })
      .attr("height", function (d) {
        return height - y(d.sum);
      })
      .style("fill", "#0080EA")
      .style("opacity", 0.7)
      .on("mouseenter", (data) => onMouseEnter(data, colors.xpblue.main))
      .on("mouseleave", onMouseLeave);

    //Append the Probability Line
    let lineData = d3.map(calibrationMap).entries();
    let mainLine = svg
      .append("path")
      .datum(lineData)
      .attr("class", "main--line")
      .attr("fill", "none")
      .attr("stroke", "#A9A9A9")
      .attr("stroke-width", 2)
      .attr(
        "d",
        d3
          .line()
          .x(function (d) {
            return x(d.key);
          })
          .y(function (d) {
            return y1(d.value);
          })
      );

    // //Add vertical tooltip line
    let verticalLine = svg
      .append("line")
      .attr("x1", 0)
      .attr("x2", 0)
      .attr("y1", 0)
      .attr("y2", height)
      .attr("stroke", "#334155")
      .attr("class", "verticalLine")
      .style("stroke-dasharray", "5,5"); //dashed array for line;

    let horizontalLine = svg
      .append("line")
      .attr("stroke", "#334155")
      .attr("class", "horizontalLine")
      .style("stroke-dasharray", "5,5"); //dashed array for line;

    let fixedLine = svg
      .append("line")
      .attr("y1", 0)
      .attr("y2", height)
      .attr("stroke", "#0f172a")
      .attr("stroke-width", "2px")
      .attr("class", "fixedLine");

    d3.select(".fixedLine").attr("x1", 125).attr("x2", 125);

    let circle = svg
      .append("circle")
      .attr("class", "focus")
      .attr("opacity", 0)
      .attr("r", 5)
      .attr("fill", "white")
      .attr("stroke", "#334155");

    // Handmade legend
    svg
      .append("circle")
      .attr("cx", 120)
      .attr("cy", 30)
      .attr("r", 6)
      .style("fill", colors.xppink.main);
    svg
      .append("circle")
      .attr("cx", 120)
      .attr("cy", 60)
      .attr("r", 6)
      .style("fill", colors.xpblue.main);
    svg
      .append("text")
      .attr("x", 135)
      .attr("y", 30)
      .text(negativeLabel)
      .style("font-size", "15px")
      .attr("alignment-baseline", "middle")
      .attr("fill", "#A9A9A9");
    svg
      .append("text")
      .attr("x", 135)
      .attr("y", 60)
      .text(positiveLabel)
      .style("font-size", "15px")
      .attr("alignment-baseline", "middle")
      .attr("fill", "#A9A9A9");

    svg
      .selectAll("text")
      .style("fill", darkMode ? colors.white.main : "")
      .style("stroke", darkMode ? colors.white.main : "");

    svg.selectAll(".domain").style("stroke", darkMode ? colors.white.main : "");
    svg.selectAll(".fixedLine").style("stroke", darkMode ? colors.white.main : "");

    d3.select(svgContainer.current)
      .on("mousemove", function () {
        let xLoc = d3.mouse(d3.select(svgContainer.current).node())[0];

        if (xLoc > margin.left && xLoc < width + margin.left) {
          d3.select(".verticalLine").attr("display", "null");
          d3.select(".horizontalLine").attr("display", "null");
          //Update the vertical line
          var xPos = xLoc - margin.left;

          d3.select(".verticalLine").attr("transform", function () {
            return "translate(" + xPos + ",0)";
          });

          //Find location of x line
          var pathLength = mainLine.node().getTotalLength();
          var xVal = xPos;
          var beginning = xVal,
            end = pathLength;
          while (true) {
            let target = Math.floor((beginning + end) / 2);
            pos = mainLine.node().getPointAtLength(target);
            if ((target === end || target === beginning) && pos.x !== xVal) {
              break;
            }
            if (pos.x > xVal) end = target;
            else if (pos.x < xVal) beginning = target;
            else break; //position found
          }
          circle.attr("opacity", 1).attr("cx", xVal).attr("cy", pos.y);

          //Update the horizontal line
          d3.select(".horizontalLine")
            .attr("transform", function () {
              return "translate(" + xPos + "," + pos.y + ")";
            })
            .attr("x1", 0)
            .attr("x2", width - xPos);
        }
      })
      .on("click", function () {
        //Return the cLoc of the mouse
        let xLoc = d3.mouse(d3.select(svgContainer.current).node())[0];

        d3.select(".fixedLine")
          .attr("x1", function () {
            return xLoc < margin.left
              ? margin.left
              : xLoc > width + margin.left
              ? width
              : xLoc - margin.left;
          })
          .attr("x2", function () {
            return xLoc < margin.left
              ? margin.left
              : xLoc > width + margin.left
              ? width
              : xLoc - margin.left;
          });

        var xPos = xLoc - margin.left - 10;
        var pathLength = mainLine.node().getTotalLength();
        var xVal = xPos;
        var beginning = xVal,
          end = pathLength,
          target;
        while (true) {
          target = Math.floor((beginning + end) / 2);
          pos = mainLine.node().getPointAtLength(target);
          if ((target === end || target === beginning) && pos.x !== xVal) {
            break;
          }
          if (pos.x > xVal) end = target;
          else if (pos.x < xVal) beginning = target;
          else break; //position found
        }

        //Set the global state of the click value
        // let yCutoff = y1.invert(pos.y);
        let xCutoff = x.invert(pos.x);

        // console.log("The yCutoff value is", yCutoff);

        setThreshold(xCutoff);
      });
  };

  const valueLabelFormat = (value) => {
    const mark = marks.find((mark) => mark.value === value);
    return mark ? mark.label : "";
  };

  return (
    <>
      <XBox display="flex" flexDirection="column">
        <XBox display="flex" width="40%" flexDirection="column" gap={1}>
          <XTypography variant="h5" fontWeight="bold" fontSize="18px">
            Score
          </XTypography>
          <XTypography variant="h5" fontWeight="bold" color="xpblue" fontSize="24px">
            {binNumber + " Bins"}
          </XTypography>
        </XBox>
        <XBox width={"100%"} pl={3} pr={1}>
          <Slider
            // sx={{
            //   "& .MuiSlider-markLabel": {
            //     color: "red",
            //   },
            // }}
            defaultValue={10}
            valueLabelDisplay="auto"
            step={null}
            marks={marks}
            min={marks[0].value}
            max={marks[marks.length - 1].value}
            onChange={handleBinNumberChange}
            valueLabelFormat={valueLabelFormat}
          />
        </XBox>
      </XBox>
      <div id="tooltip" className="tooltip">
        <div className="tooltip-framework">
          <span id="framework" />
        </div>
        <div className="tooltip-value">
          <span id="count" />
        </div>
      </div>
      <div ref={svgContainer}>
        <svg ref={svgRef} id="probability-chart"></svg>
      </div>
    </>
  );
}

export default ProbabilityChart;

// Typechecking props for the ProbabilityChart.
// Only the props the component actually reads are listed; PropTypes ignores extra props,
// so callers that still pass others are unaffected.
ProbabilityChart.propTypes = {
  // Key into the `probabilityData` returned by useModel().
  dataset: PropTypes.string,
  // Plot margins in px; used in width/translate arithmetic, so it must be provided.
  margin: PropTypes.shape({
    top: PropTypes.number,
    right: PropTypes.number,
    bottom: PropTypes.number,
    left: PropTypes.number,
  }).isRequired,
  // Called with the clicked score when the user clicks the chart.
  setThreshold: PropTypes.func,
  // Accepted for compatibility; layout follows collapsedWidth from useModel().
  collapsedWidth: PropTypes.bool,
};
