import { any, filter, pipe, prop } from 'lodash/fp'

import { DeltaValue, InlineMessage } from '@cmpkit/base'
import AlertIcon from '@cmpkit/icon/lib/glyph/alert'

import LabeledValue from '@/components/LabeledValue'
import {
	FilterRuleModel,
	MetricModel,
	MetricsDataUnit,
	OptimizationIDModel,
} from '@/generated'
import intl from '@/locale'
import { useMetricsQuery } from '@/modules/bi/queries'
import { useInterpretabilityStatisticQuery } from '@/modules/core/queries'

import {
	formatMericValue,
	formatTotalDiffValue,
	getMetricDiffByType,
	getTypesByMetricSchema,
	getUnprefixedMetricKey,
} from './helpers'

/**
 * Compact summary of the predicted metrics listed in `METRICS`, each shown
 * as its formatted value plus a delta badge.
 *
 * Shows a skeleton while the metric schema, the statistic, or the parent's
 * data is loading. Shows an inline error message if either query fails.
 *
 * @param isLoading - loading flag from the parent. While true the widget shows
 *   the skeleton and keeps the statistic query disabled.
 * @param queryParams - optimizations and filters for the statistic request.
 *   When `null` the statistic query is kept disabled.
 */
export const MetricsSummaryWidget = ({
	isLoading: isDataLoading,
	queryParams,
}: {
	isLoading: boolean
	queryParams: {
		optimizations: OptimizationIDModel[]
		filters: FilterRuleModel[]
	} | null
}) => {
	// Only the 'predict_default' metric definitions are needed to resolve
	// display types for the metrics in METRICS.
	const metricsQuery = useMetricsQuery<MetricModel[]>({
		select: pipe([filter(({ type }) => type === 'predict_default')]),
	})
	const interpretabilityStatisticQuery = useInterpretabilityStatisticQuery(
		// `queryParams` may still be null here. `enabled` below keeps the request
		// from running in that case (assumes the hook forwards `enabled`).
		queryParams!,
		{
			refetchOnMount: true,
			staleTime: 0,
			// Wait for params, the metric schema and the parent's data first.
			enabled: !!queryParams && !!metricsQuery.data && !isDataLoading,
			// Turn the raw statistic into one display row per METRICS entry.
			select: (data) => {
				return METRICS.map((metric) => {
					const value = extractMetricValues(
						data?.metrics as MetricsDataStatistic,
						metric
					)
					return getMetricValues({
						metric,
						data: value as unknown as MetricsDataUnit,
						metricsSchema: metricsQuery.data as MetricModel[],
					})
				})
			},
		}
	)
	const isLoading =
		any(prop('isLoading'), [metricsQuery, interpretabilityStatisticQuery]) ||
		isDataLoading
	const isError = any(prop('isError'), [
		metricsQuery,
		interpretabilityStatisticQuery,
	])
	const error = metricsQuery.error || interpretabilityStatisticQuery.error
	if (isLoading) {
		return (
			<div className='mt-2 flex w-full space-x-4'>
				<div className='h-10 w-24 animate-pulse rounded bg-accent-4' />
				<div className='h-10 w-20 animate-pulse rounded bg-accent-4' />
				<div className='h-10 w-20 animate-pulse rounded bg-accent-4' />
			</div>
		)
	}
	if (isError) {
		return (
			<InlineMessage
				icon={<AlertIcon />}
				variant='danger'
				className='mr-5 text-xs'
			>
				{error?.message || intl.get('fatal_error_title')}
			</InlineMessage>
		)
	}

	return (
		<div className='flex space-x-4'>
			{!!interpretabilityStatisticQuery.data &&
				interpretabilityStatisticQuery.data.map((item) => {
					return (
						// There is one row per METRICS entry. `item.key` is used as the
						// React list key (assumes getUnprefixedMetricKey keeps keys distinct).
						<LabeledValue
							key={item.key}
							label={`${intl.get(`metric.${item.key}`).d(item.key)}`}
						>
							<span className='text-lg font-semibold'>{item.value}</span>
							<DeltaValue className='ml-1 text-xs' value={item.diffValue}>
								{formatTotalDiffValue(Number(item.diffValue), item.diffType)}
							</DeltaValue>
						</LabeledValue>
					)
				})}
		</div>
	)
}
/**
 * Flat statistic map read by `extractMetricValues`. Keys have the form
 * `<metric>.init`, `<metric>.final` or `<metric>.optimal`.
 *
 * NOTE(review): downstream code treats the looked-up values as plain numbers
 * (the `final` value is formatted directly), so `MetricsDataUnit` may not be
 * the accurate value type here — confirm against the API schema.
 */
type MetricsDataStatistic = Record<string, MetricsDataUnit>

/** Prefixed names of the predicted metrics shown by the summary widget, in display order. */
const METRICS = ['predict_revenue', 'predict_gross_profit_margin', 'predict_sales_items']
/**
 * Builds the display row for a single metric.
 *
 * Finds the schema entry whose `name`, up to the first '.', equals `metric`,
 * and derives the value and diff types from it. It then computes the diff and
 * formats the final value (a falsy `final` is shown as 0).
 *
 * @param metric - prefixed metric name, e.g. 'predict_revenue'.
 * @param data - init/final/optimal values for that metric.
 * @param metricsSchema - metric definitions used to resolve display types.
 * @returns `{ key, value, diffType, diffValue }` ready for rendering.
 */
const getMetricValues = ({
	metric,
	data,
	metricsSchema,
}: {
	metric: string
	data: MetricsDataUnit
	metricsSchema: MetricModel[]
}) => {
	const belongsToMetric = (schema: MetricModel) =>
		schema.name.split('.')[0] === metric
	const schemaEntry = metricsSchema?.find(belongsToMetric)

	// NOTE(review): when no schema entry matches, `undefined` is passed through
	// the `!` assertion. Confirm that getTypesByMetricSchema tolerates that.
	const types = getTypesByMetricSchema(schemaEntry!)
	const diffValue = getMetricDiffByType(data, types.diffType)
	const unprefixedKey = getUnprefixedMetricKey(metric)

	return {
		key: unprefixedKey,
		value: formatMericValue({ value: data.final || 0, type: types.dataType }),
		diffType: types.diffType,
		diffValue,
	}
}
/**
 * Gathers the `<metric>.init`, `<metric>.final` and `<metric>.optimal`
 * entries of a flat statistic map into one object.
 *
 * @param data - flat map keyed by `<metric>.<stage>`. A nullish map is
 *   tolerated.
 * @param metric - metric prefix, e.g. 'predict_revenue'.
 * @returns `{ init, final, optimal }`. A field is `undefined` when its key is
 *   absent or `data` is nullish.
 */
const extractMetricValues = (data: MetricsDataStatistic, metric: string) => {
	const read = (stage: 'init' | 'final' | 'optimal') =>
		data?.[`${metric}.${stage}`]
	return { init: read('init'), final: read('final'), optimal: read('optimal') }
}
