import { useApiKey } from "components/Authorisation/ApiKeyContext";
import { useModel } from "hooks";
import { ConfusionChart } from "shared/models/BinaryPerformance/components/ConfusionChart";
import { Card } from "@mui/material";

import React, { useEffect, useState } from "react";
import XBox from "components/XBox";
import LoadingSpinner from "shared/Animations/LoadingAnimation";
import PropTypes from "prop-types";

// Base URL of the backend API, read from the REACT_APP_HOST_URL environment
// variable. It prefixes the evaluation endpoint requested by ConfusionMatrix.
const hostUrl = process.env.REACT_APP_HOST_URL;

const ConfusionMatrix = React.memo(({ threshold, setThreshold, setMetrics }) => {
  const { apiKey, activeWorkspace } = useApiKey();
  const { model_id, selectedVersion, selectedPartition, profileDataLoading } = useModel();

  const [selectedDataset, setSelectedDataset] = useState({
    value: "train",
    label: "Train",
  });
  const [confusionData, setConfusionData] = useState({});
  const [filteredData, setFilteredData] = useState([]);
  const [performanceData, setPerformanceData] = useState([]);

  const getPerformanceData = async (model_id, version_id, partition_id) => {
    try {
      //Get the Partition ID
      const response = await fetch(
        `${hostUrl}/v1/organisations/${activeWorkspace?.organisation_id}/teams/${activeWorkspace?.team_id}/models/${model_id}/versions/${version_id}/partitions/${partition_id}/evaluation`,
        {
          method: "GET",
          headers: {
            "Content-Type": "application/json",
            api_key: apiKey,
          },
        }
      );

      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }

      const data = await response.json();

      //Set the Performance data
      setPerformanceData(data[0]);
    } catch (error) {
      console.error("Error:", error);
    }
  };

  // Fetch the performance data from the model endpoint
  useEffect(() => {
    if (selectedVersion && selectedPartition) {
      getPerformanceData(model_id, selectedVersion?.value, selectedPartition?.value);
    }
  }, [model_id, selectedVersion, selectedPartition]);

  useEffect(() => {
    if (
      filteredData &&
      filteredData.scores &&
      filteredData.auc_pr &&
      filteredData.mcc &&
      filteredData.log_loss &&
      filteredData.roc_auc &&
      filteredData.roc &&
      filteredData.precision_recall_curve
    ) {
      //Build the scores dataset
      const scores = filteredData.scores[Math.round(threshold)]; //[Number(Math.round((threshold) * 100))];

      // console.log("The filtered Data is", filteredData)

      const metrics = [
        { title: "F1", value: scores.f1.toFixed(3) },
        { title: "Accuracy", value: scores.accuracy.toFixed(3) },
        { title: "Precision", value: scores.precision.toFixed(3) },
        { title: "Recall", value: scores.recall.toFixed(3) },
        { title: "AUC PR", value: filteredData.auc_pr.toFixed(3) },
        { title: "MCC", value: filteredData.mcc.toFixed(3) },
        { title: "Log Loss", value: filteredData.log_loss.toFixed(3) },
        { title: "ROC AUC", value: filteredData.roc_auc.toFixed(3) },
      ];

      const confusionData = Object.fromEntries(
        Object.entries(scores).filter(([k, v]) => ["fn", "fp", "tn", "tp"].includes(k))
      );

      setConfusionData(confusionData);
      setMetrics(metrics);
    }
  }, [threshold, filteredData]);

  useEffect(() => {
    const filterData = () => {
      if (selectedDataset.value === "validation") {
        return performanceData.evaluation?.validation;
      } else if (selectedDataset.value === "train") {
        return performanceData.evaluation?.train;
      }
    };

    const newData = filterData();

    // console.log("The filtered data is", newData)
    if (newData) {
      setFilteredData(newData);
    }
  }, [selectedDataset, performanceData]);

  if (profileDataLoading) {
    return (
      <Card
        sx={{
          display: "flex",
          width: "100%",
          flexDirection: "row",
          height: "500px",
        }}
      >
        <XBox display="flex" width="100%" alignItems="center" justifyContent="center">
          <LoadingSpinner size={50} animationType="pulse" />
        </XBox>
      </Card>
    );
  }

  return (
    <ConfusionChart
      id={"confusion--plot"}
      confusionData={confusionData}
      threshold={threshold}
      margin={{
        top: 0,
        right: 0,
        bottom: 0,
        left: 0,
      }}
    />
  );
});

export default ConfusionMatrix;

// Runtime prop validation for ConfusionMatrix. None of the props is marked
// as required.
ConfusionMatrix.propTypes = {
  threshold: PropTypes.number,
  setThreshold: PropTypes.func,
  setMetrics: PropTypes.func,
};
