import { useAssayContext } from '@resistapp/client/contexts/assay-context';
import { useSampleDataContext } from '@resistapp/client/contexts/sample-data-context';
import { useOverviewContext } from '@resistapp/client/contexts/use-overview-context/use-overview-context';
import type { OverviewDatum } from '@resistapp/client/data-utils/plot-data/build-overview-line-data';
import { getMetricAndLevel } from '@resistapp/client/utils/metric-utils';
import { L2Target } from '@resistapp/common/assays';
import type { ResistanceLevel } from '@resistapp/common/statistics/resistance-index';
import { ProcessMode } from '@resistapp/common/types';
import { useMemo } from 'react';

/**
 * Optional overrides accepted by `useMetricAndLevel`.
 * Any field left out falls back to the corresponding value from context.
 */
interface UseMetricAndLevelOptions {
  /** Targets to use instead of the sample-data query filters' `selectedTargets`. */
  selectedTargets?: L2Target[];
  /** Process mode to use instead of the overview context's `processMode`. */
  processMode?: ProcessMode;
}

/**
 * Result pair produced for a single datum: the metric value and its
 * resistance level. Either field may be `null`; both are `null` when
 * no datum is supplied to the hook.
 */
interface CacheEntry {
  metric: null | number;
  level: null | ResistanceLevel;
}

/**
 * Module-level cache shared by every `useMetricAndLevel` call.
 * The outer Map is keyed by a string describing the non-datum inputs; each value
 * is a WeakMap keyed by the datum object reference, so a datum's entry can be
 * garbage-collected once the datum itself is no longer referenced.
 */
const cachesByKey: Map<string, WeakMap<OverviewDatum, CacheEntry>> = new Map();

/**
 * Builds the string that identifies one combination of the non-datum inputs
 * passed to `getMetricAndLevel`.
 *
 * Each part is stringified and the parts are JSON-encoded as an array. A `-` or `,`
 * occurring inside any value therefore cannot make two different combinations
 * produce the same key. The earlier `${a}-${b}-${c}-${targets.join(',')}` form
 * could collide that way and return a result computed for other inputs.
 */
function buildCacheKey(
  processMode: unknown,
  metricMode: unknown,
  chartUnit: unknown,
  selectedTargets: readonly unknown[],
): string {
  return JSON.stringify([String(processMode), String(metricMode), String(chartUnit), selectedTargets.map(String)]);
}

/**
 * Returns the metric value and resistance level for `datum`.
 *
 * `processMode` and `selectedTargets` default to the overview context's process mode
 * and the sample-data query filters' selected targets. Pass them in `options` to
 * override either one.
 *
 * Results are memoized per component with `useMemo` and are also cached across
 * components in two levels:
 * 1. `cachesByKey` (a Map) holds one WeakMap per combination of the non-datum inputs
 *    (process mode, metric mode, chart unit, selected targets).
 * 2. That WeakMap maps each datum object, by reference, to its computed entry.
 *
 * @param datum - The overview datum to evaluate. When `undefined`, the result is
 *   `{ metric: null, level: null }`.
 * @param options - Optional overrides for process mode and selected targets.
 * @returns The metric and level. Cached entries are shared between callers and must not be mutated.
 */
export function useMetricAndLevel(
  datum: OverviewDatum | undefined,
  options: UseMetricAndLevelOptions = {},
): CacheEntry {
  const { queryFilters } = useSampleDataContext();
  const { getGroup, allAssays } = useAssayContext();
  const { activeChartUnit, metricMode, processMode: contextProcessMode } = useOverviewContext();

  const effectiveProcessMode = options.processMode ?? contextProcessMode;
  const effectiveSelectedTargets = options.selectedTargets ?? queryFilters.filters.selectedTargets;

  return useMemo((): CacheEntry => {
    if (!datum) {
      return { metric: null, level: null };
    }

    // NOTE(review): getGroup and allAssays are not part of the cache key. This relies on
    // them not changing once loaded from the server. If they can change while the same
    // datum objects are reused, cached entries would be stale — confirm.
    const cacheKey = buildCacheKey(effectiveProcessMode, metricMode, activeChartUnit, effectiveSelectedTargets);

    // Level 1: find (or create) the per-datum WeakMap for this input combination.
    let datumCache = cachesByKey.get(cacheKey);
    if (!datumCache) {
      datumCache = new WeakMap();
      cachesByKey.set(cacheKey, datumCache);
    }

    // Level 2: reuse the entry previously computed for this exact datum object.
    const cached = datumCache.get(datum);
    if (cached) {
      return cached;
    }

    const [metric, level] = getMetricAndLevel(
      datum,
      effectiveSelectedTargets,
      metricMode,
      effectiveProcessMode,
      activeChartUnit,
      getGroup,
      allAssays,
    );

    const result: CacheEntry = { metric, level };
    datumCache.set(datum, result);
    return result;
  }, [datum, effectiveSelectedTargets, metricMode, effectiveProcessMode, activeChartUnit, getGroup, allAssays]);
}
