import React, { useEffect, useMemo, useState } from "react";
import PropTypes from "prop-types";
import { Card } from "@mui/material";
import { useAuth0 } from "@auth0/auth0-react";
import { useModel } from "hooks";
import { ConfusionChart } from "shared/models/BinaryPerformance/components/ConfusionChart";
import XBox from "components/XBox";
import LoadingSpinner from "shared/Animations/LoadingAnimation";
import { useBinaryPerformanceQuery } from "api/query";

const ConfusionMatrix = React.memo(({ threshold, setMetrics }) => {
  const { logout } = useAuth0();
  const { selectedPartition, profileDataLoading } = useModel();

  const [selectedDataset, setSelectedDataset] = useState({
    value: "train",
    label: "Train",
  });
  const [confusionData, setConfusionData] = useState({});
  const [filteredData, setFilteredData] = useState([]);

  console.log("The selected partition is", selectedPartition.value)

  // Destructure data, isLoading, and error from React Query
  const { data, isLoading, error } = useBinaryPerformanceQuery(
    selectedPartition?.value,
    logout
  );

  console.log("The confusion data is", confusionData)

  // Filter the data based on the selected dataset (train or validation)
  useEffect(() => {
    if (data) {
      const filtered = selectedDataset.value === "validation"
        ? data.evaluation?.validation
        : data.evaluation?.train;

      setFilteredData(filtered || []);
    }
  }, [selectedDataset, data]);

  // Update confusion data and metrics whenever threshold or filteredData changes
  useEffect(() => {
    if (
      filteredData &&
      filteredData.scores &&
      filteredData.auc_pr &&
      filteredData.mcc &&
      filteredData.log_loss &&
      filteredData.roc_auc
    ) {
      const roundedThreshold = Math.round(threshold);
      const scores = filteredData.scores[roundedThreshold];

      if (scores) {
        const metrics = [
          { title: "F1", value: scores.f1.toFixed(3) },
          { title: "Accuracy", value: scores.accuracy.toFixed(3) },
          { title: "Precision", value: scores.precision.toFixed(3) },
          { title: "Recall", value: scores.recall.toFixed(3) },
          { title: "AUC PR", value: filteredData.auc_pr.toFixed(3) },
          { title: "MCC", value: filteredData.mcc.toFixed(3) },
          { title: "Log Loss", value: filteredData.log_loss.toFixed(3) },
          { title: "ROC AUC", value: filteredData.roc_auc.toFixed(3) },
        ];

        const filteredConfusionData = {
          fn: scores.fn,
          fp: scores.fp,
          tn: scores.tn,
          tp: scores.tp,
        };

        setConfusionData(filteredConfusionData);
        setMetrics(metrics);
      }
    }
  }, [threshold, filteredData, setMetrics]);

  // Handle loading state
  if (profileDataLoading || isLoading) {
    return (
      <Card
        sx={{
          display: "flex",
          width: "100%",
          flexDirection: "row",
          height: "500px",
        }}
      >
        <XBox display="flex" width="100%" alignItems="center" justifyContent="center">
          <LoadingSpinner size={50} animationType="pulse" />
        </XBox>
      </Card>
    );
  }

  // Handle error state
  if (error) {
    return (
      <Card
        sx={{
          display: "flex",
          width: "100%",
          flexDirection: "row",
          height: "500px",
          alignItems: "center",
          justifyContent: "center",
          padding: 2,
        }}
      >
        <XBox color="error.main">
          {error.message || "Error loading performance data."}
        </XBox>
      </Card>
    );
  }

  // Render the ConfusionChart with the processed confusionData
  return (
    <ConfusionChart
      id="confusion--plot"
      confusionData={confusionData}
      threshold={threshold}
      margin={{
        top: 0,
        right: 0,
        bottom: 0,
        left: 0,
      }}
    />
  );
});

// Runtime prop validation. Only the props the component actually reads are
// declared; extra props passed by callers are ignored by PropTypes.
ConfusionMatrix.propTypes = {
  threshold: PropTypes.number.isRequired,
  setMetrics: PropTypes.func.isRequired,
};

export default ConfusionMatrix;
