import { chain, get, groupBy, isEqual, isNil, keys, mean } from 'lodash';
import { AssayInfo, L2Targets } from '../assays';
import { FullAbundance } from '../types';
import { safeAssert } from '../utils';

// Name of the reference assay. Its rows are excluded from the per-assay aggregation in this module and
// its values are used to normalise the other assays.
const AY1 = 'AY1';
// Ct value substituted when an abundance has no meanCt.
// NOTE(review): LOQ presumably stands for "limit of quantification" — confirm.
const CT_LOQ = 28;

/**
 * Expresses the fold change between the before and after samples on a log10 scale.
 *
 * @param beforeAbundancesInclAy1 - Abundances of the before samples, including the AY1 reference assay.
 * @param afterAbundancesInclAy1 - Abundances of the after samples, including the AY1 reference assay.
 * @param limitToTargetOrAssay - Optional restriction forwarded unchanged to `getLog2FoldChange`.
 * @param allAssays - Assay metadata forwarded unchanged to `getLog2FoldChange`.
 * @returns The log10 fold change, or `null` when `getLog2FoldChange` returns `null`.
 */
export function getRelative10FoldChange(
  beforeAbundancesInclAy1: FullAbundance[] | undefined,
  afterAbundancesInclAy1: FullAbundance[] | undefined,
  limitToTargetOrAssay: L2Targets | string | undefined,
  allAssays: AssayInfo[],
): number | null {
  const log2FoldChange = getLog2FoldChange(
    beforeAbundancesInclAy1,
    afterAbundancesInclAy1,
    limitToTargetOrAssay,
    allAssays,
  );
  if (log2FoldChange === null) {
    return null;
  }

  // Change of base: log10(x) = log2(x) / log2(10)
  return log2FoldChange / Math.log2(10);
}

/**
 * Computes the log10 ratio of the worst-case copies per litre after versus before, limited to [-5, 5].
 *
 * @param beforeAbundancesInclAy1 - Abundances of the before samples, including the AY1 reference assay.
 * @param afterAbundancesInclAy1 - Abundances of the after samples, including the AY1 reference assay.
 * @param limitToTargetOrAssay - Optional restriction to one target or one assay.
 * @param allAssays - Assay metadata used to resolve a target to its assays.
 * @returns `null` when either input is empty, when nothing in scope was detected on either side, or when
 *   no total could be derived for either side; `-5` when only the after total is unavailable; `5` when
 *   only the before total is unavailable; otherwise `log10(after / before)` limited to [-5, 5].
 */
export function get10FoldChangeInVolume(
  beforeAbundancesInclAy1: FullAbundance[] | undefined,
  afterAbundancesInclAy1: FullAbundance[] | undefined,
  limitToTargetOrAssay: L2Targets | string | undefined,
  allAssays: AssayInfo[],
): number | null {
  // Nothing to compare when one of the sides has no abundances at all
  if (!beforeAbundancesInclAy1?.length || !afterAbundancesInclAy1?.length) {
    return null;
  }

  // Worst-case copies per L and detection flag for each side
  const [before, detectedBefore] = getWorstCaseCopiesPerL(beforeAbundancesInclAy1, limitToTargetOrAssay, allAssays);
  const [after, detectedAfter] = getWorstCaseCopiesPerL(afterAbundancesInclAy1, limitToTargetOrAssay, allAssays);

  // Neither side detected any of the assays in question, or no total is available for either side
  // (for example because the normalisation metadata is missing)
  const nothingDetected = !(detectedBefore || detectedAfter);
  if (nothingDetected || (isNil(before) && isNil(after))) {
    return null;
  }

  // Something was detected before, but no total could be derived for the after samples
  // (for example because not even AY1 was detected there)
  if (isNil(after)) {
    return -5;
  }

  // Something was detected after, but no total could be derived for the before samples
  // (for example because not even AY1 was detected there)
  if (isNil(before)) {
    return 5;
  }

  // Both totals exist. They assume the worst case, i.e. that every undetected assay is present at LOD.
  const log10Ratio = Math.log10(after / before);
  return Math.min(5, Math.max(-5, log10Ratio));
}

/**
 * Sums the copies per litre of the in-scope assays of one sample set, counting every assay without a
 * copies-per-litre value as if it were present at the limit of detection (LOD).
 *
 * @param abundances - Abundances of one sample set, including the AY1 reference assay.
 * @param limitToTargetOrAssay - A key of `L2Targets` (every assay of that target is in scope), a single
 *   assay name, or a falsy value for no restriction.
 * @param allAssays - Assay metadata used to resolve a target to its assays.
 * @returns A `[total, detected]` tuple. `total` is `null` when there are no abundances, when no
 *   normalisation factor can be derived from AY1, when an in-scope abundance has a meanCt but no
 *   copiesPerL, or when the sum is zero. `detected` is `true` when at least one in-scope abundance has a
 *   copiesPerL value.
 */
function getWorstCaseCopiesPerL(
  abundances: FullAbundance[] | undefined,
  limitToTargetOrAssay: L2Targets | string | undefined,
  allAssays: AssayInfo[],
) {
  if (!abundances?.length) {
    return [null, false] as const;
  }

  // The volume normalisation factor is not stored in the db, so it is estimated from the AY1 values of
  // this sample set. Note that these could be aggregated samples that used different filtered volumes.
  const reference = abundances.find(entry => entry.assay === AY1);
  const normalisationFactor =
    reference?.absolute && reference.copiesPerL ? reference.copiesPerL / reference.absolute : null;
  if (!normalisationFactor) {
    return [null, false] as const;
  }

  // Missing values are assumed to sit at LOD 100 for TEN_UL_DILUTED_DNA (same as in the heatmap),
  // converted to copies per litre with the normalisation factor.
  const lodCopiesPerL = 100 * normalisationFactor;

  let inScopeAssays: Set<string> | undefined;
  // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition, @typescript-eslint/no-confusing-void-expression
  if (limitToTargetOrAssay && get(L2Targets, limitToTargetOrAssay, undefined)) {
    // A target was given: every assay belonging to that target is in scope
    inScopeAssays = new Set(
      allAssays.filter(info => info.l2Target === (limitToTargetOrAssay as L2Targets)).map(info => info.assay),
    );
  } else if (limitToTargetOrAssay) {
    // Anything else is treated as a single assay name
    inScopeAssays = new Set([limitToTargetOrAssay]);
  }

  const inScopeAbundances = abundances.filter(
    entry => entry.assay !== AY1 && (inScopeAssays === undefined || inScopeAssays.has(entry.assay)),
  );

  // An abundance with a Ct value but without copies per litre is unexpected: report it and give up
  const missingCopiesPerL = inScopeAbundances.find(entry => !isNil(entry.meanCt) && isNil(entry.copiesPerL));
  if (missingCopiesPerL !== undefined) {
    console.error('UNEXPECTEDLY MISSING COPIES PER L FOR ABUNDACE', missingCopiesPerL);
    return [null, false] as const;
  }

  // For each abundance, use its copiesPerL value if available, otherwise use LOD
  let detected = false;
  let worstCaseTotal = 0;
  for (const entry of inScopeAbundances) {
    if (!isNil(entry.copiesPerL)) {
      detected = true;
    }
    worstCaseTotal += entry.copiesPerL || lodCopiesPerL;
  }

  return [worstCaseTotal || null, detected] as const;
}

/**
 * Converts the log2 fold change between the before and after samples into a plain ratio.
 *
 * @param beforeAbundancesInclAy1 - Abundances of the before samples, including the AY1 reference assay.
 * @param afterAbundancesInclAy1 - Abundances of the after samples, including the AY1 reference assay.
 * @param limitToTargetOrAssay - Optional restriction forwarded unchanged to `getLog2FoldChange`.
 * @param allAssays - Assay metadata forwarded unchanged to `getLog2FoldChange`.
 * @returns `2 ^ log2FoldChange`, or `null` when no log2 fold change is available.
 */
export function getFoldChangeRatio(
  beforeAbundancesInclAy1: FullAbundance[] | undefined,
  afterAbundancesInclAy1: FullAbundance[] | undefined,
  limitToTargetOrAssay: L2Targets | string | undefined,
  allAssays: AssayInfo[],
): number | null {
  const exponent = getLog2FoldChange(beforeAbundancesInclAy1, afterAbundancesInclAy1, limitToTargetOrAssay, allAssays);
  if (isNil(exponent)) {
    return null;
  }
  return Math.pow(2, exponent);
}

/**
 * Restricts both sample sets to the requested target or assay (always keeping the AY1 reference rows)
 * and delegates the actual computation to `getLog2FoldChangeOrThrow`.
 *
 * @param allBeforeAbundancesInclAy1 - Abundances of the before samples, including the AY1 reference assay.
 * @param allAfterAbundancesInclAy1 - Abundances of the after samples, including the AY1 reference assay.
 * @param limitToTargetOrAssay - A key of `L2Targets` (every assay of that target is in scope), a single
 *   assay name, or a falsy value for no restriction.
 * @param allAssays - Assay metadata used to resolve a target to its assays.
 * @returns `null` when either input is undefined or empty, otherwise whatever
 *   `getLog2FoldChangeOrThrow` returns for the (possibly filtered) abundances.
 */
export function getLog2FoldChange(
  allBeforeAbundancesInclAy1: FullAbundance[] | undefined,
  allAfterAbundancesInclAy1: FullAbundance[] | undefined,
  limitToTargetOrAssay: L2Targets | string | undefined,
  allAssays: AssayInfo[],
): number | null {
  if (!allAfterAbundancesInclAy1?.length || !allBeforeAbundancesInclAy1?.length) {
    return null;
  }

  // Works out which assay names are in scope; `undefined` means "no restriction"
  const resolveInScopeAssays = (): Set<string> | undefined => {
    // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition, @typescript-eslint/no-confusing-void-expression
    if (limitToTargetOrAssay && get(L2Targets, limitToTargetOrAssay, undefined)) {
      // A target was given: every assay belonging to that target is in scope
      return new Set(
        allAssays.filter(info => info.l2Target === (limitToTargetOrAssay as L2Targets)).map(info => info.assay),
      );
    }
    // Anything else that is truthy is treated as a single assay name
    return limitToTargetOrAssay ? new Set([limitToTargetOrAssay]) : undefined;
  };

  const inScopeAssays = resolveInScopeAssays();
  if (inScopeAssays === undefined) {
    return getLog2FoldChangeOrThrow(allBeforeAbundancesInclAy1, allAfterAbundancesInclAy1);
  }

  // The reference assay is always kept, it is needed for normalisation
  const isReferenceOrInScope = (abundance: FullAbundance): boolean =>
    abundance.assay === AY1 || inScopeAssays.has(abundance.assay);

  return getLog2FoldChangeOrThrow(
    allBeforeAbundancesInclAy1.filter(isReferenceOrInScope),
    allAfterAbundancesInclAy1.filter(isReferenceOrInScope),
  );
}

/**
 * This function calculates the overall fold change metric for a set of genes by comparing after-sample Ct values
 * relative to before-sample Ct values. The aggregation across genes and biological replicates is done using the
 * geometric mean, which is suitable for logarithmic data.
 *
 * Note 1: The order of Ct values within biological replicates does not affect the overall fold change calculation.
 * Note 2: We assume CT_LOQ for missing Ct values, taking into account sample pairs where a gene is not detected in
 * one of the samples and assuming the "minimum" fold change: i.e. that the missing gene count is just below LOQ.
 *
 * The preconditions below are checked with `safeAssert`.
 * NOTE(review): the function name suggests a failed check throws — confirm against `safeAssert`.
 *
 * @param beforeAbundancesInclAy1 - Non-empty abundances of the before samples, including the AY1 reference
 *   rows; must have the same length and the same set of assays as `afterAbundancesInclAy1`.
 * @param afterAbundancesInclAy1 - Non-empty abundances of the after samples, including the AY1 reference rows.
 * @returns The negated arithmetic mean of all ΔΔCt values (i.e. the log2 fold change), or `null` when the
 *   inputs contain no assay other than AY1 and there is therefore nothing to average.
 */
export function getLog2FoldChangeOrThrow(
  beforeAbundancesInclAy1: FullAbundance[],
  afterAbundancesInclAy1: FullAbundance[],
): number | null {
  safeAssert(
    Boolean(beforeAbundancesInclAy1.length) && Boolean(afterAbundancesInclAy1.length),
    'Abundances and afterAbundances must have a length',
  );

  safeAssert(
    beforeAbundancesInclAy1.length === afterAbundancesInclAy1.length,
    'Abundances and afterAbundances must have the same length',
  );

  // Reference assay (AY1) rows of the before samples: required, and each must carry a Ct value
  const ay1Before = beforeAbundancesInclAy1.filter(a => a.assay === AY1);
  safeAssert(ay1Before.length > 0, 'Reference assay (AY1) meanCt not found in abundances.');
  safeAssert(
    ay1Before.every(a => a.meanCt !== null),
    'Reference assay (AY1) meanCt must not be null in influent samples',
  );

  // Reference assay (AY1) rows of the after samples: same requirements, and as many rows as before
  const ay1After = afterAbundancesInclAy1.filter(a => a.assay === AY1);
  safeAssert(ay1After.length > 0, 'Reference assay (AY1) meanCt not found in afterAbundances.');
  safeAssert(
    ay1After.every(a => a.meanCt !== null),
    'Reference assay (AY1) meanCt must not be null in effluent samples',
  );
  safeAssert(
    ay1After.length === ay1Before.length,
    'Reference assay (AY1) abundances must have the same length in influent and effluent samples',
  );

  // Group the non-reference abundances by assay
  const beforeByAssay = groupBy(
    beforeAbundancesInclAy1.filter(a => a.assay !== AY1),
    a => a.assay,
  );
  const afterByAssay = groupBy(
    afterAbundancesInclAy1.filter(a => a.assay !== AY1),
    a => a.assay,
  );

  const assays = keys(beforeByAssay);
  const afterAssays = keys(afterByAssay);
  safeAssert(isEqual(new Set(assays), new Set(afterAssays)), 'Before and after assays must be identical');

  const ddCts = calculateDdCts(beforeByAssay, afterByAssay, ay1Before, ay1After);

  // Only reference rows were supplied (e.g. the requested assay is absent from both sample sets):
  // lodash `mean([])` is NaN, so report "no value" explicitly instead of leaking NaN to callers.
  if (ddCts.length === 0) {
    return null;
  }

  // The arithmetic mean of the ΔΔCt values is the geometric mean of the fold changes in log2 space
  const geometricMeanInLogSpace = mean(ddCts);

  // A higher Ct means less template, hence the sign flip: log2 fold change = -ΔΔCt
  return -geometricMeanInLogSpace;
}

/**
 * Computes the ΔΔCt value (ΔCt after minus ΔCt before) of every abundance of every assay, where ΔCt is the
 * abundance's Ct minus the Ct of the AY1 reference row at the same index.
 *
 * @param beforeByAssay - Before-sample abundances grouped by assay name (without AY1).
 * @param afterByAssay - After-sample abundances grouped by assay name (without AY1).
 * @param ay1Before - AY1 reference rows of the before samples, paired with each assay's rows by index.
 * @param ay1After - AY1 reference rows of the after samples, paired with each assay's rows by index.
 * @returns All ΔΔCt values, assay by assay in the key order of `beforeByAssay`.
 */
export function calculateDdCts(
  beforeByAssay: Record<string, FullAbundance[]>,
  afterByAssay: Record<string, FullAbundance[]>,
  ay1Before: FullAbundance[],
  ay1After: FullAbundance[],
): number[] {
  // ΔCt of each abundance against the reference row at the same index. A missing Ct is replaced by CT_LOQ,
  // so a gene that appears from nowhere, or drops below LOQ, during treatment is still taken into account
  // (not skipped) with the "minimum" fold change: i.e. as if its count were just below LOQ.
  const toDeltaCts = (abundances: FullAbundance[], reference: FullAbundance[]): number[] =>
    abundances.map((abundance, index) => (abundance.meanCt || CT_LOQ) - (reference[index].meanCt as number));

  const ddCts: number[] = [];
  for (const assay of keys(beforeByAssay)) {
    const beforeGroup = beforeByAssay[assay];
    const afterGroup = afterByAssay[assay];
    safeAssert(beforeGroup.length === afterGroup.length, `More ${assay}s before or after`);
    safeAssert(beforeGroup.length === ay1Before.length, `More ${assay}s before or after`);

    const deltaCtsBefore = toDeltaCts(beforeGroup, ay1Before);
    const deltaCtsAfter = toDeltaCts(afterGroup, ay1After);

    deltaCtsAfter.forEach((deltaCtAfter, index) => {
      ddCts.push(deltaCtAfter - deltaCtsBefore[index]);
    });
  }
  return ddCts;
}
