import React, { useState, useEffect } from 'react';
import { observer } from 'mobx-react';
import Box from '@mui/material/Box';
import {
  TextField,
  Typography,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Paper,
  IconButton,
  InputAdornment,
} from '@mui/material';
import Pagination from '@mui/material/Pagination';
import { Add } from '@mui/icons-material';
import DeleteIcon from '@mui/icons-material/Delete';
import VisibilityIcon from '@mui/icons-material/Visibility';
import FileDownloadIcon from '@mui/icons-material/FileDownload';
import { MenuItem, makeStyles } from '@material-ui/core';
import { useToasts } from 'react-toast-notifications';
import RefreshIcon from '@mui/icons-material/Refresh';
import { useForm, useFieldArray, useWatch, Controller } from 'react-hook-form';

import Button from '../../components/buttons/Button';
import { useStore } from '../../hooks/useStore';
import Flex from '../../components/utils/flex/Flex';
import { getTextOverflowEllipsisStyles } from '../../components/typography/utils';
import { CustomSlider } from '../../components/cards/newConversationCard/SelectModelDetails';
import ShowModelEvaluationResultsModal from '../../components/modal/ShowModelEvaluationResultModal';
import AWS from 'aws-sdk';

import {
  COLOR_BORDER_PRIMARY,
  COLOR_WHITE,
  GRAY_COLORS,
  COLOR_RED,
  COLOR_BORDER_SECONDARY,
  COLOR_ICON_PRIMARY,
} from '../../constants/colors';
import { JobStatus } from '../../constants/modelEvaluationJobStatus';

/** Border shared by every cell of the evaluation-jobs table. */
const TABLE_CELL_BORDER = `1px solid ${GRAY_COLORS.GRAY_12}`;

/**
 * JSS classes for the model-evaluation page:
 * - `tableCell`: bordered cell of the jobs table;
 * - `customSliderDiv`: full-width column wrapper around one labelled slider;
 * - `customSliderP`: the slider's label text.
 */
const useStyles = makeStyles({
  tableCell: {
    border: TABLE_CELL_BORDER,
  },
  customSliderDiv: {
    width: '100%',
    display: 'flex',
    flexDirection: 'column',
    gap: '1rem',
    marginBottom: '14px',
  },
  customSliderP: {
    fontWeight: 500,
    fontSize: '14px',
  },
});

/** Values collected by the "create evaluation job" form. */
type FormInputs = {
  jobName: string;
  modelIdentifier: string;
  taskType: string;
  metrics: { metricName: string; dataSet: string }[];
  max_gen_len: number;
  temperature: number;
  top_p: number;
};

/** One select option: `value` is what gets submitted, `name` is what is displayed (passed through i18n). */
type DropDownOption = {
  id: number;
  value: string;
  name: string;
};

/** Number of evaluation jobs shown per table page. */
const ITEMS_PER_PAGE = 10;

/** Bedrock foundation models offered for evaluation; the submitted value is the model ARN. */
const modelDropDownValues: DropDownOption[] = [
  {
    id: 1,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/meta.llama2-13b-chat-v1',
    name: 'Llama 2 Chat 13B',
  },
  {
    id: 2,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/meta.llama2-70b-chat-v1',
    name: 'Llama 2 Chat 70B',
  },
  {
    id: 3,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/cohere.command-text-v14',
    name: 'Command',
  },
  {
    id: 4,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/cohere.command-light-text-v14',
    name: 'Command Light',
  },
  {
    id: 5,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-instant-v1',
    name: 'Claude Instant',
  },
  {
    id: 6,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2:1',
    name: 'Claude v2.1',
  },
  {
    id: 7,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-v2',
    name: 'Claude v2',
  },
  {
    id: 8,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/ai21.j2-mid-v1',
    name: 'Jurassic-2 Mid',
  },
  {
    id: 9,
    value: 'arn:aws:bedrock:us-east-1::foundation-model/ai21.j2-ultra-v1',
    name: 'Jurassic-2 Ultra',
  },
];

/** Evaluation task types. */
const taskTypesDropDownValues: DropDownOption[] = [
  {
    id: 1,
    value: 'Generation',
    name: 'General text generation',
  },
  {
    id: 2,
    value: 'Summarization',
    name: 'Text summarization',
  },
  {
    id: 3,
    value: 'QuestionAndAnswer',
    name: 'Question and answer',
  },
  {
    id: 4,
    value: 'Classification',
    name: 'Text classification',
  },
];

/** Built-in metrics; each may be chosen at most once per job. */
const metricNamesDropDownValues: DropDownOption[] = [
  {
    id: 1,
    value: 'Builtin.Robustness',
    name: 'Robustness',
  },
  {
    id: 2,
    value: 'Builtin.Accuracy',
    name: 'Accuracy',
  },
  {
    id: 3,
    value: 'Builtin.Toxicity',
    name: 'Toxicity',
  },
];

/** Built-in data sets; which ones are offered is decided by `filterDatasets`. */
const dataSetsDropDownValues: DropDownOption[] = [
  {
    id: 1,
    value: 'Builtin.RealToxicityPrompts',
    name: 'RealToxicityPrompts',
  },
  {
    id: 2,
    value: 'Builtin.BOLD',
    name: 'BOLD',
  },
  {
    id: 3,
    value: 'Builtin.T-REx',
    name: 'T-REx',
  },
  {
    id: 4,
    value: 'Builtin.WikiText2',
    name: 'WikiText2',
  },
  {
    id: 5,
    value: 'Builtin.Gigaword',
    name: 'Gigaword',
  },
  {
    id: 6,
    value: 'Builtin.NaturalQuestions',
    name: 'NaturalQuestions',
  },
  {
    id: 7,
    value: 'Builtin.BoolQ',
    name: 'BoolQ',
  },
  {
    id: 8,
    value: 'Builtin.TriviaQA',
    name: 'TriviaQA',
  },
  {
    id: 9,
    value: "Builtin.Women's Ecommerce Clothing Reviews",
    name: "Women's Ecommerce Clothing Reviews",
  },
];

/** Metric options offered for a task type: Toxicity (id 3) is not offered for Classification. */
const getMetricOptions = (taskType: string): DropDownOption[] =>
  metricNamesDropDownValues.filter(item => taskType !== 'Classification' || item.id !== 3);

/**
 * Returns the data sets that may be paired with `metricName` under `taskType`, matched by option id.
 * For Generation the allowed sets also depend on the metric; an unknown task type (or, for
 * Generation, a missing/unknown metric) yields no options.
 */
const filterDatasets = (
  taskType: string,
  metricName: string | undefined,
  dataSets: DropDownOption[]
): DropDownOption[] =>
  dataSets.filter(item => {
    switch (taskType) {
      case 'Classification':
        return item.id === 9;
      case 'Summarization':
        return item.id === 5;
      case 'QuestionAndAnswer':
        return [6, 7, 8].includes(item.id);
      case 'Generation':
        switch (metricName) {
          case 'Builtin.Accuracy':
            return item.id === 3;
          case 'Builtin.Robustness':
            return [2, 4].includes(item.id);
          case 'Builtin.Toxicity':
            return [1, 2].includes(item.id);
          default:
            return false;
        }
      default:
        return false;
    }
  });

/** Text after the last '/' of an ARN-like identifier (the whole string when it contains no '/'). */
const lastPathSegment = (value: string): string => value.substring(value.lastIndexOf('/') + 1);

/**
 * Admin page for model evaluation.
 *
 * The top panel is a form that creates an evaluation job through `modelStore.createModelEvaluation`.
 * The bottom panel is a paginated table of jobs loaded from `modelStore.listModelEvaluation`.
 * Completed jobs can be viewed (via `showOutPutData` + results modal) or downloaded straight from S3.
 */
const ModelEvaluationPage = () => {
  const {
    localizationStore: { i18next: i18n },
    modelStore: { createModelEvaluation, listModelEvaluation, showOutPutData },
  } = useStore();

  const { addToast } = useToasts();
  const classes = useStyles();

  const {
    control,
    handleSubmit,
    formState: { errors },
    reset,
    getValues,
    setValue,
    trigger,
  } = useForm<FormInputs>({
    defaultValues: {
      jobName: '',
      modelIdentifier: '',
      taskType: '',
      temperature: 0,
      max_gen_len: 0,
      top_p: 0,
      metrics: [{ metricName: '', dataSet: '' }],
    },
    mode: 'onChange',
    reValidateMode: 'onChange',
  });

  // Watched because the metric / data-set options are derived from them during render.
  const taskType = useWatch({ control, name: 'taskType' });
  const metrics = useWatch({ control, name: 'metrics' });

  const { fields, append, remove } = useFieldArray({
    control,
    name: 'metrics',
  });

  // Start empty so the table shows the "no records" state, not a blank row, until the first load finishes.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- job shape comes from the untyped store
  const [allEvaluationJob, setAllEvaluationJob] = useState<any[]>([]);
  const [showOutPutModal, setShowOutPutModal] = useState(false);
  const [evaluationResultData, setEvaluationResultData] = useState([{}]);
  const [currentJobData, setCurrentJobData] = useState({});
  const [page, setPage] = useState(1);

  // A refresh can shrink the list; clamp so the table never lands on a page past the end.
  const pageCount = Math.max(1, Math.ceil(allEvaluationJob.length / ITEMS_PER_PAGE));
  const currentPage = Math.min(page, pageCount);
  const visibleJobs = allEvaluationJob.slice((currentPage - 1) * ITEMS_PER_PAGE, currentPage * ITEMS_PER_PAGE);

  const metricOptions = getMetricOptions(taskType);

  const handleChangePage = (_event: React.ChangeEvent<unknown>, newPage: number) => {
    setPage(newPage);
  };

  /**
   * Reloads the job list from the store.
   * Errors are logged and toasted rather than rethrown, so callers (mount effect, refresh button,
   * post-create reload) never leave an unhandled rejection.
   */
  const getAllModelEvaluationJob = async () => {
    try {
      const jobs = await listModelEvaluation();
      setAllEvaluationJob(jobs ?? []);
    } catch (error) {
      console.error('error while listing evaluation jobs : ', error);
      addToast('something went wrong', { appearance: 'error' });
    }
  };

  /**
   * Submits a new evaluation job, then reloads the list and clears the form.
   * On failure the form is NOT reset, so the user can correct the input and retry.
   */
  const onSubmit = async (data: FormInputs) => {
    try {
      await createModelEvaluation({
        jobName: data.jobName,
        modelIdentifier: data.modelIdentifier,
        metrics: data.metrics,
        taskType: data.taskType,
        max_gen_len: Number(data.max_gen_len),
        temperature: Number(data.temperature),
        top_p: Number(data.top_p),
      });
    } catch (error) {
      console.error('error while creating evaluation job : ', error);
      addToast(`${i18n.t('adminPortal.modelevaluation.toast.error')}`, { appearance: 'error' });
      return;
    }
    // Reload failures are reported by getAllModelEvaluationJob itself and must not be shown as a create failure.
    await getAllModelEvaluationJob();
    addToast(`${i18n.t('adminPortal.modelevaluation.toast.success')}`, { appearance: 'success' });
    reset();
  };

  /** Fetches the evaluation output for a job and opens the results modal. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- job shape comes from the untyped store
  const handleShowOutput = async (jobData: any) => {
    try {
      const data = await showOutPutData(jobData);
      setEvaluationResultData(data);
      setCurrentJobData(jobData);
      setShowOutPutModal(true);
    } catch (error) {
      console.error(error, 'something went wrong');
      addToast('something went wrong', { appearance: 'error' });
    }
  };

  /**
   * Downloads every output file of a completed job from the model-evaluation S3 bucket.
   * Each file is saved as `<jobName>-<timestamp>[-<n>].jsonl`; the index suffix is only added when
   * there are several files, so their names cannot collide.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- job shape comes from the untyped store
  const handleDownload = async (jobData: any) => {
    try {
      const bucket = process.env.REACT_APP_MODEL_EVALUATION_BUCKET;
      if (!bucket) {
        throw new Error('REACT_APP_MODEL_EVALUATION_BUCKET is not configured');
      }

      // SECURITY(review): REACT_APP_* variables are inlined into the public JS bundle at build time, so
      // these AWS credentials are readable by anyone who loads the page. Prefer pre-signed URLs or a backend proxy.
      const s3 = new AWS.S3({
        region: process.env.REACT_APP_AWS_REGION,
        accessKeyId: process.env.REACT_APP_AWS_ACCESS_KEY,
        secretAccessKey: process.env.REACT_APP_AWS_SECRET_KEY,
      });

      const jobArnNumber = lastPathSegment(jobData.jobArn);
      // NOTE(review): the table reads `modelIdentifiers[0]`, but this reads `modelId` — confirm both exist on job items.
      const modelId = lastPathSegment(jobData.modelId);
      const prefix = `output-data/${jobData.jobName}/${jobArnNumber}/models/${modelId}/taskTypes/${jobData.taskType}/datasets/`;

      const listing = await s3.listObjectsV2({ Bucket: bucket, Prefix: prefix }).promise();
      const fileKeys = (listing.Contents ?? [])
        .map(object => object.Key)
        .filter((key): key is string => Boolean(key));

      if (fileKeys.length === 0) {
        addToast('no output files found for this job', { appearance: 'info' });
        return;
      }

      const timestamp = new Date().getTime();
      for (let index = 0; index < fileKeys.length; index++) {
        const response = await s3.getObject({ Bucket: bucket, Key: fileKeys[index] }).promise();

        // Wrap the S3 body in a Blob so the browser can save it through a temporary object URL.
        const blob = new Blob([response.Body as BlobPart], { type: response.ContentType });
        const url = window.URL.createObjectURL(blob);
        const suffix = fileKeys.length > 1 ? `-${index + 1}` : '';

        const link = document.createElement('a');
        link.href = url;
        link.setAttribute('download', `${jobData.jobName}-${timestamp}${suffix}.jsonl`);
        document.body.appendChild(link);
        link.click();

        // Clean up the temporary anchor and object URL once the download has been triggered.
        document.body.removeChild(link);
        window.URL.revokeObjectURL(url);
      }
      addToast('file downloaded successfully', { appearance: 'success' });
    } catch (error) {
      console.error(error);
      addToast('something went wrong', { appearance: 'error' });
    }
  };

  const handleRefresh = async () => {
    await getAllModelEvaluationJob();
  };

  useEffect(() => {
    // Initial load only; getAllModelEvaluationJob handles its own errors.
    void getAllModelEvaluationJob();
  }, []);

  return (
    <>
      <Box
        style={{
          backgroundColor: COLOR_WHITE,
          padding: '30px',
          borderRadius: '8px',
          border: `1px solid ${COLOR_BORDER_PRIMARY}`,
        }}
      >
        <form onSubmit={handleSubmit(onSubmit)}>
          <Box display="flex" justifyContent="space-between" alignItems="center" gap="1rem" marginBottom="1rem">
            <Typography variant="h4" sx={{ fontSize: '22px', fontWeight: 700 }}>
              {i18n.t('adminPortal.modelevaluation.heading')}
            </Typography>
            <Box sx={{ display: 'flex', justifyContent: 'flex-end' }}>
              <div>
                <Button type="submit" variant="contained" component="span" onClick={handleSubmit(onSubmit)}>
                  <Add sx={{ marginRight: '8px' }} />
                  {i18n.t('adminPortal.modelevaluation.create')}
                </Button>
              </div>
            </Box>
          </Box>
          <Box display="flex" gap="1rem">
            <Box display="flex" flexDirection="column" gap={'1rem'} width="60%">
              <Controller
                control={control}
                name="jobName"
                rules={{
                  required: true,
                  pattern: {
                    value: /^[a-z0-9-]+$/,
                    message: 'Only Small letters, numbers, and hyphens allowed',
                  },
                }}
                render={({ field: { onChange, value, ref } }) => (
                  <TextField
                    label="Job name"
                    variant="outlined"
                    sx={{ width: '100%' }}
                    onChange={onChange}
                    value={value}
                    // inputRef (not ref) so react-hook-form can focus the input when validation fails.
                    inputRef={ref}
                    InputProps={{
                      startAdornment: (
                        <InputAdornment position="start">{process.env.REACT_APP_CUSTOMER_NAME}-</InputAdornment>
                      ),
                    }}
                  />
                )}
              />
              {errors.jobName?.type === 'required' && <p style={{ color: COLOR_RED }}>Required!</p>}
              {errors.jobName?.type === 'pattern' && (
                <p style={{ color: COLOR_RED }}>Only Small letters, numbers, and hyphens are allowed</p>
              )}

              <Controller
                control={control}
                name="modelIdentifier"
                rules={{ required: true }}
                render={({ field: { onChange, value, ref } }) => (
                  <TextField select label="Model" sx={{ width: '100%' }} onChange={onChange} value={value} inputRef={ref}>
                    {modelDropDownValues.map(item => (
                      <MenuItem value={item.value} key={item.value}>
                        <Flex>
                          <Typography variant={'subtitle2'} sx={getTextOverflowEllipsisStyles(1)}>
                            {i18n.t(item.name)}
                          </Typography>
                        </Flex>
                      </MenuItem>
                    ))}
                  </TextField>
                )}
              />

              {errors.modelIdentifier?.type === 'required' && <p style={{ color: COLOR_RED }}>Required!</p>}

              <Controller
                control={control}
                name="taskType"
                rules={{ required: true }}
                render={({ field: { onChange, value, ref } }) => (
                  <TextField
                    select
                    label="Task"
                    sx={{ width: '100%' }}
                    onChange={event => {
                      onChange(event);
                      // Metric and data-set options depend on the task type, so earlier picks may now be invalid.
                      getValues('metrics').forEach((_, metricIndex) => {
                        setValue(`metrics.${metricIndex}.metricName`, '');
                        setValue(`metrics.${metricIndex}.dataSet`, '');
                      });
                    }}
                    value={value}
                    inputRef={ref}
                  >
                    {taskTypesDropDownValues.map(item => (
                      <MenuItem value={item.value} key={item.value}>
                        <Flex>
                          <Typography variant={'subtitle2'} sx={getTextOverflowEllipsisStyles(1)}>
                            {i18n.t(item.name)}
                          </Typography>
                        </Flex>
                      </MenuItem>
                    ))}
                  </TextField>
                )}
              />
              {errors.taskType?.type === 'required' && <p style={{ color: COLOR_RED }}>Required!</p>}

              <Box
                sx={{
                  justifyContent: 'space-between',
                  width: '100%',
                  gap: '0.5rem',
                  border: `1px solid ${COLOR_BORDER_SECONDARY}`,
                  padding: '12px',
                  borderRadius: '8px',
                }}
              >
                <Typography variant="h6" sx={{ fontSize: '16px', fontWeight: 400 }}>
                  Metrics
                </Typography>
                {fields.map((field, index) => (
                  <Box key={field.id} sx={{ display: 'flex', gap: '1rem', marginTop: '1rem' }}>
                    <Box width={'50%'}>
                      <Controller
                        control={control}
                        name={`metrics.${index}.metricName`}
                        rules={{ required: true }}
                        render={({ field: { onChange, value, ref } }) => (
                          <TextField
                            label="Metric"
                            select
                            sx={{ width: '100%' }}
                            onChange={event => {
                              onChange(event);
                              // Data-set options depend on the metric; drop a selection that may now be stale.
                              setValue(`metrics.${index}.dataSet`, '');
                            }}
                            value={value}
                            inputRef={ref}
                          >
                            {metricOptions.map(item => (
                              <MenuItem
                                // A metric may appear in only one row.
                                disabled={metrics?.some(metric => metric.metricName === item.value)}
                                value={item.value}
                                key={item.value}
                              >
                                <Flex>
                                  <Typography variant={'subtitle2'} sx={getTextOverflowEllipsisStyles(1)}>
                                    {i18n.t(item.name)}
                                  </Typography>
                                </Flex>
                              </MenuItem>
                            ))}
                          </TextField>
                        )}
                      />
                      {errors?.metrics?.[index]?.metricName?.type === 'required' && (
                        <p style={{ color: COLOR_RED }}>Required!</p>
                      )}
                    </Box>

                    <Box width={'50%'}>
                      <Controller
                        control={control}
                        name={`metrics.${index}.dataSet`}
                        rules={{ required: true }}
                        render={({ field: { onChange, value, ref } }) => (
                          <TextField
                            label="Data sets"
                            select
                            sx={{ width: '100%' }}
                            onChange={onChange}
                            value={value}
                            inputRef={ref}
                          >
                            {filterDatasets(taskType, metrics?.[index]?.metricName, dataSetsDropDownValues).map(item => (
                              <MenuItem value={item.value} key={item.value}>
                                <Flex>
                                  <Typography variant="subtitle2" sx={getTextOverflowEllipsisStyles(1)}>
                                    {i18n.t(item.name)}
                                  </Typography>
                                </Flex>
                              </MenuItem>
                            ))}
                          </TextField>
                        )}
                      />
                      {errors?.metrics?.[index]?.dataSet?.type === 'required' && (
                        <p style={{ color: COLOR_RED }}>Required!</p>
                      )}
                    </Box>

                    {fields.length > 1 && (
                      <Box>
                        <IconButton
                          id="removeModel"
                          onClick={() => remove(index)}
                          sx={{
                            background: 'blue',
                            color: COLOR_WHITE,
                            borderRadius: '8px',
                            ':hover': { background: COLOR_RED },
                          }}
                        >
                          <DeleteIcon />
                        </IconButton>
                      </Box>
                    )}
                  </Box>
                ))}

                <Box sx={{ marginTop: '1rem', justifyContent: 'right', display: 'flex' }}>
                  <Button
                    id="addModel"
                    onClick={() => {
                      const metricsValue = getValues('metrics');
                      const lastMetric = metricsValue[metricsValue.length - 1];
                      if (lastMetric && (lastMetric.dataSet === '' || lastMetric.metricName === '')) {
                        // Surface "Required!" on the incomplete row instead of adding another one.
                        void trigger('metrics');
                        return;
                      }
                      // Each metric can be used once, so a further row could never be filled in.
                      if (metricsValue.length >= getMetricOptions(getValues('taskType')).length) {
                        return;
                      }
                      append({ metricName: '', dataSet: '' });
                    }}
                    sx={{ width: 'auto' }}
                  >
                    Add Metric
                  </Button>
                </Box>
              </Box>
            </Box>
            <Box display="flex" flexDirection="column" gap={'1rem'} width="40%" paddingLeft="20px">
              <div className={classes.customSliderDiv}>
                <p className={classes.customSliderP}>Temperature</p>
                <Controller
                  control={control}
                  name={'temperature'}
                  rules={{ required: true }}
                  render={({ field: { onChange, value, ref } }) => (
                    <CustomSlider
                      min={0}
                      max={1}
                      step={0.1}
                      value={value}
                      ref={ref}
                      onChange={onChange}
                      valueLabelDisplay={'on'}
                    />
                  )}
                />
              </div>
              {errors.temperature && <p style={{ color: COLOR_RED }}>{errors.temperature.message}</p>}

              <div className={classes.customSliderDiv}>
                <p className={classes.customSliderP}>Max Tokens</p>
                <Controller
                  control={control}
                  name={'max_gen_len'}
                  rules={{ required: true }}
                  render={({ field: { onChange, value, ref } }) => (
                    <CustomSlider
                      min={0}
                      max={4096}
                      step={1}
                      value={value}
                      ref={ref}
                      onChange={onChange}
                      valueLabelDisplay={'on'}
                    />
                  )}
                />
              </div>
              {errors.max_gen_len && <p style={{ color: COLOR_RED }}>{errors.max_gen_len.message}</p>}
              <div className={classes.customSliderDiv}>
                <p className={classes.customSliderP}>Top_p</p>
                <Controller
                  control={control}
                  name={'top_p'}
                  rules={{ required: true }}
                  render={({ field: { onChange, value, ref } }) => (
                    <CustomSlider
                      min={0}
                      max={1}
                      step={0.1}
                      value={value}
                      ref={ref}
                      onChange={onChange}
                      valueLabelDisplay={'on'}
                    />
                  )}
                />
              </div>
              {errors.top_p && <p style={{ color: COLOR_RED }}>{errors.top_p.message}</p>}
            </Box>
          </Box>
        </form>
      </Box>
      <Box
        style={{
          backgroundColor: COLOR_WHITE,
          padding: '20px 30px',
          borderRadius: '8px',
          border: '1px solid #E5E7EB',
          marginTop: '20px',
        }}
      >
        <Box display="flex" justifyContent="space-between" marginBottom="20px">
          <Typography variant="h4" sx={{ fontSize: '22px', fontWeight: 700 }}>
            {i18n.t('adminPortal.modelevaluation.jobs.heading')}
          </Typography>
          <IconButton onClick={() => handleRefresh()}>
            <RefreshIcon
              fontSize={'large'}
              sx={{ border: '1px solid ', borderRadius: '50%', padding: '2px', color: '#131F4D' }}
            />
          </IconButton>
        </Box>
        {allEvaluationJob.length > 0 ? (
          <Box>
            <TableContainer component={Paper}>
              <Table>
                <TableHead>
                  <TableRow>
                    <TableCell className={classes.tableCell}>Job Name</TableCell>
                    <TableCell className={classes.tableCell}>Status</TableCell>
                    <TableCell className={classes.tableCell}>Model</TableCell>
                    <TableCell className={classes.tableCell}>Evaluation type</TableCell>
                    <TableCell className={classes.tableCell}>Metric</TableCell>
                    <TableCell className={classes.tableCell}>Created At</TableCell>
                    <TableCell className={classes.tableCell}>Actions</TableCell>
                  </TableRow>
                </TableHead>
                <TableBody>
                  {/* eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped store items */}
                  {visibleJobs.map((item: any, index: number) => {
                    // Output can only be viewed or downloaded once the job has completed.
                    const isCompleted = item.status === JobStatus.COMPLETED;
                    const actionColor = isCompleted ? COLOR_ICON_PRIMARY : GRAY_COLORS.GRAY_6;
                    return (
                      <TableRow key={item.jobArn ?? index}>
                        <TableCell className={classes.tableCell}>{item.jobName}</TableCell>
                        <TableCell className={classes.tableCell}>{item.status}</TableCell>
                        <TableCell className={classes.tableCell}>
                          {modelDropDownValues.find(model => model.value === item?.modelIdentifiers?.[0])?.name}
                        </TableCell>
                        <TableCell className={classes.tableCell}>{item?.evaluationTaskTypes?.[0]}</TableCell>
                        <TableCell className={classes.tableCell}>
                          {/* eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped store items */}
                          {item?.metrics?.map((metric: any) => metric.metricName).join(', ')}
                        </TableCell>
                        <TableCell className={classes.tableCell}>{item.creationTime}</TableCell>
                        <TableCell className={classes.tableCell}>
                          <div style={{ display: 'flex', justifyContent: 'space-between' }}>
                            <IconButton
                              size={'small'}
                              onClick={() => {
                                if (isCompleted) {
                                  void handleDownload(item);
                                }
                              }}
                              sx={{ color: actionColor }}
                            >
                              <FileDownloadIcon />
                            </IconButton>
                            <IconButton
                              size={'small'}
                              onClick={() => {
                                if (isCompleted) {
                                  void handleShowOutput(item);
                                }
                              }}
                              sx={{ color: actionColor }}
                            >
                              <VisibilityIcon />
                            </IconButton>
                          </div>
                        </TableCell>
                      </TableRow>
                    );
                  })}
                </TableBody>
              </Table>
            </TableContainer>
            <Pagination
              sx={{ justifyContent: 'end', marginTop: '20px' }}
              count={pageCount}
              page={currentPage}
              onChange={handleChangePage}
            />
          </Box>
        ) : (
          <Typography variant="h4" sx={{ fontSize: '22px', fontWeight: 700, textAlign: 'center', padding: '1rem' }}>
            {i18n.t('adminPortal.modelevaluation.jobs.noRecords')}
          </Typography>
        )}
      </Box>
      {showOutPutModal && (
        <ShowModelEvaluationResultsModal
          isOpen={showOutPutModal}
          onClose={() => {
            setShowOutPutModal(false);
          }}
          outputData={evaluationResultData}
          currentJobData={currentJobData}
        />
      )}
    </>
  );
};

// The module's default export is the page wrapped in mobx-react's `observer`.
const ObservedModelEvaluationPage = observer(ModelEvaluationPage);

export default ObservedModelEvaluationPage;
