"""Tests for engine/scoring.py.

Covers: operator dispatch, class index formula, certification level
boundaries, and critical non-conformance override gate.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from types import SimpleNamespace

import pytest

from engine.scoring import (
    ScoringResult,
    compute_class_indices,
    dispatch_operator,
    evaluate_assessment,
    select_certification_level,
)


# ---------------------------------------------------------------------------
# Helpers — lightweight stubs that mimic ORM objects (no DB needed)
# ---------------------------------------------------------------------------

def _req(req_id, class_id, operator, threshold, mandatory=False, criticality=1):
    """Build a lightweight requirement stub exposing the given values as attributes.

    ``req_id`` is stored under the attribute name ``id``; every other
    argument is stored under its own name.
    """
    attrs = {
        "id": req_id,
        "class_id": class_id,
        "operator": operator,
        "threshold": threshold,
        "mandatory": mandatory,
        "criticality": criticality,
    }
    return SimpleNamespace(**attrs)


def _level(level, class_min, overall_min, class_scope="ALL", critical_threshold=3):
    """Build a lightweight certification-level stub.

    ``class_min`` and ``overall_min`` are exposed as ``class_min_index`` and
    ``overall_min_index``; the remaining arguments keep their own names.
    """
    attrs = {
        "level": level,
        "class_min_index": class_min,
        "overall_min_index": overall_min,
        "class_scope": class_scope,
        "critical_threshold": critical_threshold,
    }
    return SimpleNamespace(**attrs)


# ---------------------------------------------------------------------------
# 4.1 — dispatch_operator
# ---------------------------------------------------------------------------

@pytest.mark.parametrize(
    ("operator", "value", "threshold", "expected"),
    [
        # ">=" cases
        (">=", 80, 70, True),
        (">=", 60, 70, False),
        (">=", 70, 70, True),  # exactly at boundary
        # "<=" cases
        ("<=", 50, 70, True),
        ("<=", 80, 70, False),
        ("<=", 70, 70, True),  # exactly at boundary
        # "=" cases
        ("=", "5", "5", True),
        ("=", "5", "6", False),
        # "yes-no" cases
        ("yes-no", "yes", "yes", True),
        ("yes-no", "no", "yes", False),
        ("yes-no", "YES", "yes", True),  # case-insensitive
        ("yes-no", "no", "no", True),
    ],
)
def test_operator_dispatch(operator, value, threshold, expected):
    """Check the verdict of ``dispatch_operator`` for each tabulated case."""
    verdict = dispatch_operator(operator, value, threshold)
    assert verdict == expected


def test_operator_dispatch_unknown_raises():
    """The "!=" operator is expected to raise ValueError mentioning "Unknown operator"."""
    unsupported = "!="
    with pytest.raises(ValueError, match="Unknown operator"):
        dispatch_operator(unsupported, 1, 1)


# ---------------------------------------------------------------------------
# 4.2 — compute_class_indices (one formula per class I–IV)
# ---------------------------------------------------------------------------

def test_class_index_all_met():
    """Both class-1 answers clear the threshold, so the index should be 1.0."""
    requirements = [
        _req(req_id, class_id=1, operator=">=", threshold="70")
        for req_id in (1, 2)
    ]
    answers = {1: "80", 2: "90"}

    class_indices = compute_class_indices(requirements, answers)

    assert class_indices[1] == pytest.approx(1.0)


def test_class_index_none_met():
    """Both class-2 answers fall short of the threshold, so the index should be 0.0."""
    requirements = [
        _req(req_id, class_id=2, operator=">=", threshold="70")
        for req_id in (1, 2)
    ]
    answers = {1: "50", 2: "60"}

    class_indices = compute_class_indices(requirements, answers)

    assert class_indices[2] == pytest.approx(0.0)


def test_class_index_partial():
    """Two of four class-3 answers clear the threshold, so the index should be 0.5."""
    requirements = [
        _req(req_id, class_id=3, operator=">=", threshold="70")
        for req_id in (1, 2, 3, 4)
    ]
    # Requirements 1 and 2 are met; 3 and 4 are not.
    answers = {1: "80", 2: "80", 3: "60", 4: "60"}

    class_indices = compute_class_indices(requirements, answers)

    assert class_indices[3] == pytest.approx(0.5)


def test_class_index_missing_answer_treated_as_unmet():
    """A requirement with no recorded answer should count as unmet (index 0.5 here)."""
    requirements = [
        _req(req_id, class_id=4, operator=">=", threshold="70")
        for req_id in (1, 2)
    ]
    # Only requirement 1 is answered; requirement 2 is deliberately absent.
    answers = {1: "80"}

    class_indices = compute_class_indices(requirements, answers)

    assert class_indices[4] == pytest.approx(0.5)


def test_class_index_multiple_classes():
    """Requirements in different classes should each produce their own index."""
    numeric_req = _req(1, class_id=1, operator=">=", threshold="70")
    yes_no_req = _req(2, class_id=2, operator="yes-no", threshold="yes")
    answers = {1: "80", 2: "yes"}

    class_indices = compute_class_indices([numeric_req, yes_no_req], answers)

    # Each class holds a single, satisfied requirement.
    assert class_indices[1] == pytest.approx(1.0)
    assert class_indices[2] == pytest.approx(1.0)


# ---------------------------------------------------------------------------
# 4.3 — select_certification_level + critical override gate
# ---------------------------------------------------------------------------

# Certification levels shared by the level-selection tests, listed from
# level 3 down to level 1. Only level 1 carries class_scope="I,II".
LEVELS = [
    _level(level=3, class_min=0.95, overall_min=0.95, class_scope="ALL"),
    _level(level=2, class_min=0.80, overall_min=0.82, class_scope="ALL"),
    _level(level=1, class_min=0.75, overall_min=0.70, class_scope="I,II"),
]
# Roman-numeral class code -> numeric class id.
CLASS_ID_BY_CODE = {
    code: class_id
    for class_id, code in enumerate(("I", "II", "III", "IV"), start=1)
}


@pytest.mark.parametrize(
    ("class_indices", "overall", "expected_level"),
    [
        # Level 3: all classes >= 0.95, overall >= 0.95
        ({1: 0.95, 2: 0.95, 3: 0.95, 4: 0.95}, 0.95, 3),
        # Just below Level 3 on one class → falls to Level 2
        ({1: 0.94, 2: 0.95, 3: 0.95, 4: 0.95}, 0.95, 2),
        # Level 2: all classes >= 0.80, overall >= 0.82
        ({1: 0.80, 2: 0.80, 3: 0.80, 4: 0.80}, 0.82, 2),
        # Just below Level 2 on overall → falls to Level 1 (only I+II checked)
        ({1: 0.80, 2: 0.80, 3: 0.80, 4: 0.80}, 0.81, 1),
        # Level 1: classes I+II >= 0.75, overall >= 0.70
        ({1: 0.75, 2: 0.75, 3: 0.0, 4: 0.0}, 0.70, 1),
        # Just below Level 1 → not certified
        ({1: 0.74, 2: 0.75, 3: 0.0, 4: 0.0}, 0.70, None),
    ],
)
def test_certification_boundary(class_indices, overall, expected_level):
    """Check the selected level at and just below each threshold in LEVELS.

    No critical non-conformance is flagged in any of these cases.
    """
    awarded = select_certification_level(
        class_indices,
        overall,
        has_critical=False,
        levels=LEVELS,
        class_id_by_code=CLASS_ID_BY_CODE,
    )
    assert awarded == expected_level


def test_critical_override_caps_at_level1():
    """Perfect indices with a critical non-conformance flagged should yield level 1."""
    perfect_indices = dict.fromkeys((1, 2, 3, 4), 1.0)

    awarded = select_certification_level(
        perfect_indices,
        overall_index=1.0,
        has_critical=True,
        levels=LEVELS,
        class_id_by_code=CLASS_ID_BY_CODE,
    )

    assert awarded == 1


def test_critical_override_when_would_be_level3():
    """Indices sitting exactly on the level-3 thresholds should still yield level 1
    once a critical non-conformance is flagged."""
    level3_indices = dict.fromkeys((1, 2, 3, 4), 0.95)

    awarded = select_certification_level(
        level3_indices,
        overall_index=0.95,
        has_critical=True,
        levels=LEVELS,
        class_id_by_code=CLASS_ID_BY_CODE,
    )

    assert awarded == 1


# ---------------------------------------------------------------------------
# evaluate_assessment — integration (no DB, uses stubs)
# ---------------------------------------------------------------------------

def test_evaluate_assessment_full_pass():
    """One satisfied requirement per class should give overall 1.0 and level 3."""
    requirements = [
        _req(1, class_id=1, operator=">=", threshold="70", mandatory=True, criticality=1),
        _req(2, class_id=2, operator="yes-no", threshold="yes", mandatory=True, criticality=1),
        _req(3, class_id=3, operator=">=", threshold="50", mandatory=False, criticality=1),
        _req(4, class_id=4, operator=">=", threshold="50", mandatory=False, criticality=1),
    ]
    answers = {1: "100", 2: "yes", 3: "100", 4: "100"}
    ladder = [
        _level(level=3, class_min=0.95, overall_min=0.95),
        _level(level=2, class_min=0.80, overall_min=0.82),
        _level(level=1, class_min=0.75, overall_min=0.70, class_scope="I,II"),
    ]

    outcome = evaluate_assessment(requirements, answers, ladder, CLASS_ID_BY_CODE)

    assert outcome.overall_index == pytest.approx(1.0)
    assert outcome.certification_level == 3
    assert outcome.has_critical is False
    assert outcome.nonconformity_req_ids == []


def test_evaluate_assessment_critical_non_conformance():
    """A failed requirement with criticality 5 should flag a critical
    non-conformance and yield level 1."""
    failing_critical = _req(
        1, class_id=1, operator=">=", threshold="70", mandatory=True, criticality=5,
    )
    passing_minor = _req(
        2, class_id=2, operator=">=", threshold="70", mandatory=False, criticality=1,
    )
    # Requirement 1 answers 50 against a threshold of 70; its criticality (5)
    # is at or above the levels' critical_threshold (3).
    answers = {1: "50", 2: "100"}
    ladder = [
        _level(level=3, class_min=0.95, overall_min=0.95),
        _level(level=2, class_min=0.80, overall_min=0.82),
        _level(level=1, class_min=0.75, overall_min=0.70, class_scope="I,II"),
    ]

    outcome = evaluate_assessment(
        [failing_critical, passing_minor], answers, ladder, CLASS_ID_BY_CODE,
    )

    assert outcome.has_critical is True
    assert outcome.certification_level == 1
    assert 1 in outcome.nonconformity_req_ids


def test_evaluate_assessment_not_certified():
    """Two failed low-criticality requirements should give no certification level."""
    requirements = [
        _req(req_id, class_id=class_id, operator=">=", threshold="70",
             mandatory=True, criticality=1)
        for req_id, class_id in ((1, 1), (2, 2))
    ]
    # Both answers are below the threshold of 70.
    answers = {1: "60", 2: "60"}
    ladder = [
        _level(level=3, class_min=0.95, overall_min=0.95),
        _level(level=2, class_min=0.80, overall_min=0.82),
        _level(level=1, class_min=0.75, overall_min=0.70, class_scope="I,II"),
    ]

    outcome = evaluate_assessment(requirements, answers, ladder, CLASS_ID_BY_CODE)

    assert outcome.certification_level is None
    assert outcome.has_critical is False
    assert set(outcome.nonconformity_req_ids) == {1, 2}
