# Source code for finamt.agents.pipeline

"""
finamt.agents.pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~
4-agent sequential extraction pipeline.

  Agent 1  →  receipt number, date, category
  Agent 2  →  counterparty (vendor or client depending on receipt_type)
  Agent 3  →  amounts (total, vat_percentage, vat_amount)
  Agent 4  →  line items

Each agent runs sequentially (not parallel) for compatibility with local models.
After all 4 finish, results are merged in Python (no LLM validator step).

Debug output saved to ~/.finamt/debug/<receipt_id>/:
  agent1_prompt.txt / agent1_raw.txt / agent1_parsed.json
  agent2_prompt.txt / agent2_raw.txt / agent2_parsed.json
  agent3_prompt.txt / agent3_raw.txt / agent3_parsed.json
  agent4_prompt.txt / agent4_raw.txt / agent4_parsed.json
  final.json
"""

from __future__ import annotations

import json
import logging
import math
import time
from pathlib import Path

from finamt import progress as _progress

from ..models import (
    Address,
    Counterparty,
    ReceiptCategory,
    ReceiptData,
    ReceiptItem,
    ReceiptType,
)
from ..utils import parse_date, parse_decimal
from .config import AgentsConfig
from .llm_caller import call_llm
from .prompts import (
    build_agent1_prompt,
    build_agent2_prompt,
    build_agent3_prompt,
    build_agent4_prompt,
)

# Default root directory for debug artefacts (``~/.finamt/debug``). It is the
# default value of ``run_pipeline``'s ``debug_root`` parameter, under which one
# sub-directory per receipt id is created. Resolved once, at import time.
_DEFAULT_DEBUG_ROOT = Path.home() / ".finamt" / "debug"


def _ts() -> str:
    """Return the current local time as a bracketed ``[HH:MM:SS]`` stamp."""
    clock = time.strftime("%H:%M:%S")
    return f"[{clock}]"


# ---------------------------------------------------------------------------
# Field validators — null-safe type coercions
# ---------------------------------------------------------------------------


def _str_or_none(v) -> str | None:
    """Coerce *v* to a stripped string, or ``None`` when it is unusable.

    ``None`` stays ``None``. Any other value is converted with ``str()`` and
    stripped of surrounding whitespace. The result is rejected when it ends
    with ``":"`` (it looks like a field label rather than a value) or is
    shorter than two characters. Purely numeric values are kept as strings.
    """
    if v is None:
        return None
    text = str(v).strip()
    looks_like_label = text.endswith(":")
    too_short = len(text) < 2
    return None if (looks_like_label or too_short) else text


def _float_or_none(v) -> float | None:
    """Coerce *v* to a finite ``float``, or ``None`` when that is not possible.

    The value is stringified first, so numbers and numeric strings are both
    accepted. A lone comma is treated as the decimal mark (``"12,5"`` ->
    ``12.5``). When a comma *and* a dot are present, the right-most one is
    taken as the decimal mark and the other as a thousands separator
    (``"1.234,56"`` and ``"1,234.56"`` both give ``1234.56``).

    Returns:
        The parsed number, or ``None`` for ``None``, unparseable text, and
        non-finite results (``nan`` / ``inf``), which are never valid amounts.
    """
    if v is None:
        return None
    text = str(v).strip()

    # With both separators present, a plain "," -> "." swap always yields two
    # dots and fails; drop whichever separator is used for grouping instead.
    if "," in text and "." in text:
        grouping = "." if text.rfind(",") > text.rfind(".") else ","
        text = text.replace(grouping, "")

    try:
        number = float(text.replace(",", "."))
    except (ValueError, TypeError):
        return None
    return number if math.isfinite(number) else None


def _validate_agent1(raw: dict | None) -> dict:
    """Clean the metadata answer of agent 1.

    Args:
        raw: Parsed agent output, or ``None`` / empty when nothing was parsed.

    Returns:
        ``{}`` for empty input. Otherwise a dict that always has ``category``
        and has ``receipt_number`` / ``receipt_date`` only when a usable value
        was found.
    """
    if not raw:
        return {}

    cleaned: dict = {}

    number = _str_or_none(raw.get("receipt_number"))
    if number:
        cleaned["receipt_number"] = number

    date_value = raw.get("receipt_date")
    if date_value:
        date_parsed = parse_date(str(date_value))
        # Only keep the date when parse_date produced something truthy.
        if date_parsed:
            cleaned["receipt_date"] = date_parsed

    # A missing key defaults to "other"; a value ReceiptCategory rejects with
    # ValueError falls back to "other" as well.
    try:
        category = ReceiptCategory(raw.get("category", "other"))
    except ValueError:
        category = ReceiptCategory("other")
    cleaned["category"] = category

    return cleaned


def _validate_agent2(raw: dict | None) -> dict:
    """Clean the counterparty answer of agent 2.

    Every known counterparty field is passed through ``_str_or_none``; fields
    whose cleaned value is empty or ``None`` are left out of the result.

    Args:
        raw: Parsed agent output, or ``None`` / empty when nothing was parsed.

    Returns:
        A dict holding only the usable string fields (``{}`` for empty input).
    """
    if not raw:
        return {}

    field_names = (
        "name",
        "vat_id",
        "tax_number",
        "street_and_number",
        "address_supplement",
        "postcode",
        "city",
        "state",
        "country",
    )
    cleaned_pairs = ((name, _str_or_none(raw.get(name))) for name in field_names)
    return {name: value for name, value in cleaned_pairs if value}


def _validate_agent3(raw: dict | None) -> dict:
    """Clean the amounts answer of agent 3.

    Args:
        raw: Parsed agent output, or ``None`` / empty when nothing was parsed.

    Returns:
        ``{}`` for empty input. Otherwise a dict that always has ``currency``
        and has ``total_amount``, ``vat_percentage`` and ``vat_amount`` only
        when the respective value passed its plausibility check.
    """
    if not raw:
        return {}

    total = _float_or_none(raw.get("total_amount"))
    vat_pct = _float_or_none(raw.get("vat_percentage"))
    vat_amt = _float_or_none(raw.get("vat_amount"))

    cleaned: dict = {}

    # The total must be strictly positive.
    if total is not None and total > 0:
        cleaned["total_amount"] = total

    # The VAT rate is a percentage in [0, 100].
    if vat_pct is not None and 0 <= vat_pct <= 100:
        cleaned["vat_percentage"] = vat_pct

    # The VAT amount must be non-negative and, when a total was parsed at
    # all (even one rejected above), strictly smaller than that total.
    vat_plausible = (
        vat_amt is not None
        and vat_amt >= 0
        and (total is None or vat_amt < total)
    )
    if vat_plausible:
        cleaned["vat_amount"] = vat_amt

    # Currency: accept 2-4 alphabetic characters (upper-cased), else "EUR".
    currency = str(raw.get("currency") or "").strip().upper()
    if currency.isalpha() and 2 <= len(currency) <= 4:
        cleaned["currency"] = currency
    else:
        cleaned["currency"] = "EUR"

    return cleaned


def _validate_agent4(raw: dict | None) -> list:
    """Clean the line-item answer of agent 4.

    Args:
        raw: Parsed agent output, expected to carry an ``"items"`` list.

    Returns:
        A list of dicts with the keys ``description``, ``total_price``,
        ``vat_rate`` and ``vat_amount``. Entries that are not dicts, and rows
        with neither a description nor a total price, are dropped. A
        ``vat_rate`` outside [0, 100] is replaced by ``None``.
    """
    if not raw:
        return []

    cleaned = []
    for entry in raw.get("items") or []:
        if not isinstance(entry, dict):
            continue

        description = _str_or_none(entry.get("description"))
        price = _float_or_none(entry.get("total_price"))
        rate = _float_or_none(entry.get("vat_rate"))
        vat = _float_or_none(entry.get("vat_amount"))

        # Keep a row only if it carries a description or a price.
        if description or price is not None:
            rate_in_range = rate is not None and 0 <= rate <= 100
            row = {
                "description": description,
                "total_price": price,
                "vat_rate": rate if rate_in_range else None,
                "vat_amount": vat,
            }
            cleaned.append(row)
    return cleaned


# ---------------------------------------------------------------------------
# Model builder
# ---------------------------------------------------------------------------


def _build_receipt_data(
    meta: dict,
    counterparty: dict,
    amounts: dict,
    items: list,
    raw_text: str,
    receipt_type: str,
) -> ReceiptData:
    """Assemble the validated agent outputs into a ``ReceiptData`` model.

    Args:
        meta: Output of ``_validate_agent1`` (receipt number, date, category).
        counterparty: Cleaned counterparty fields; an empty dict means no
            counterparty is attached.
        amounts: Output of ``_validate_agent3`` (totals, VAT, currency).
        items: Output of ``_validate_agent4`` (list of line-item dicts).
        raw_text: The receipt text the agents worked on; stored unchanged.
        receipt_type: Value passed to ``ReceiptType(...)``.

    Returns:
        The populated ``ReceiptData`` instance.

    Line items are built best-effort: an item whose construction raises is
    skipped (and logged) so that one bad row does not discard the receipt.
    """
    # Counterparty
    # NOTE(review): a dict whose values are all None is still truthy and
    # yields an all-empty Counterparty; confirm whether that is intended.
    cp: Counterparty | None = None
    if counterparty:
        address = Address(
            street_and_number=counterparty.get("street_and_number"),
            address_supplement=counterparty.get("address_supplement"),
            postcode=counterparty.get("postcode"),
            city=counterparty.get("city"),
            state=counterparty.get("state"),
            country=counterparty.get("country"),
        )
        cp = Counterparty(
            name=counterparty.get("name"),
            vat_id=counterparty.get("vat_id"),
            tax_number=counterparty.get("tax_number"),
            address=address,
        )

    # Line items (quantity / unit price are not extracted by the agents)
    receipt_items: list[ReceiptItem] = []
    for idx, item in enumerate(items, start=1):
        try:
            receipt_item = ReceiptItem(
                description=item.get("description") or "",
                quantity=None,
                unit_price=None,
                total_price=parse_decimal(item.get("total_price")),
                vat_rate=parse_decimal(item.get("vat_rate")),
                vat_amount=parse_decimal(item.get("vat_amount")),
                category=ReceiptCategory("other"),
                position=idx,
            )
        except Exception as exc:
            # Still best-effort, but no longer silent: record which row was
            # dropped and why so missing items can be diagnosed.
            logging.getLogger(__name__).warning(
                "Skipping line item %d: %s: %s", idx, type(exc).__name__, exc
            )
            continue
        receipt_items.append(receipt_item)

    return ReceiptData(
        raw_text=raw_text,
        receipt_type=ReceiptType(receipt_type),
        counterparty=cp,
        receipt_number=meta.get("receipt_number"),
        receipt_date=meta.get("receipt_date"),
        total_amount=parse_decimal(amounts.get("total_amount")),
        vat_percentage=parse_decimal(amounts.get("vat_percentage")),
        vat_amount=parse_decimal(amounts.get("vat_amount")),
        currency=amounts.get("currency", "EUR"),
        category=meta.get("category", ReceiptCategory("other")),
        items=receipt_items,
    )


# ---------------------------------------------------------------------------
# Taxpayer-info cleanup
# ---------------------------------------------------------------------------


def _strip_taxpayer_fields(counterparty: dict, taxpayer_info: dict | None) -> dict:
    """Silently null out counterparty fields that are copies of the taxpayer's
    own data (case-insensitive, whitespace-normalised).  No warnings are emitted.

    This guards against the LLM defaulting to the taxpayer's VAT-ID / tax-number
    / name when no real counterparty data is present in the document.

    Args:
        counterparty: Cleaned counterparty fields; never mutated.
        taxpayer_info: The taxpayer's own profile, or ``None`` / empty.

    Returns:
        ``counterparty`` itself when there is no taxpayer profile, otherwise a
        shallow copy in which every matching field is set to ``None``.
    """
    if not taxpayer_info:
        return counterparty

    def _norm(v) -> str:
        # str() tolerates non-string profile values (e.g. an int postcode);
        # split/join trims the ends and collapses inner whitespace runs.
        if v is None:
            return ""
        return " ".join(str(v).split()).casefold()

    checks = {
        "name": taxpayer_info.get("name"),
        "vat_id": taxpayer_info.get("vat_id"),
        "tax_number": taxpayer_info.get("tax_number"),
        # Address sub-fields — mapped from the individual taxpayer profile fields
        "street_and_number": taxpayer_info.get("street"),
        "postcode": taxpayer_info.get("postcode"),
        "city": taxpayer_info.get("city"),
        "state": taxpayer_info.get("state"),
        "country": taxpayer_info.get("country"),
    }

    result = dict(counterparty)
    for field, taxpayer_value in checks.items():
        if not taxpayer_value:
            continue
        wanted = _norm(taxpayer_value)
        # A whitespace-only profile value normalises to "" and must not
        # "match" a field the counterparty does not have at all.
        if wanted and _norm(result.get(field)) == wanted:
            result[field] = None

    return result


# ---------------------------------------------------------------------------
# Pipeline
# ---------------------------------------------------------------------------


def run_pipeline(
    raw_text: str,
    pdf_path: str | Path | None,  # kept for API compat, not used
    receipt_type: str,
    cfg: AgentsConfig | None = None,
    receipt_id: str | None = None,
    debug_root: Path | None = _DEFAULT_DEBUG_ROOT,
    taxpayer_info: dict | None = None,
) -> ReceiptData:
    """Run the four extraction agents in sequence and merge their results.

    Args:
        raw_text: Receipt text handed to every agent prompt.
        pdf_path: Unused; kept only for API compatibility.
        receipt_type: Passed to the agent 2 prompt and to ``ReceiptType(...)``.
        cfg: Agent configuration; a default ``AgentsConfig()`` is created
            when omitted.
        receipt_id: Name of the per-receipt debug sub-directory. Debug output
            is disabled when this is empty or ``None``.
        debug_root: Root of the debug directory tree, or ``None`` to disable
            debug output.
        taxpayer_info: The taxpayer's own profile; passed to the agent 2
            prompt and used to blank out counterparty fields that merely
            repeat the taxpayer's data.

    Returns:
        The merged ``ReceiptData``. Merging happens in Python; there is no
        LLM validation step.
    """
    if cfg is None:
        cfg = AgentsConfig()
    agent_cfg = cfg.get_agent_config()

    # Debug output needs both a root directory and a receipt id.
    debug_dir: Path | None = None
    if debug_root is not None and receipt_id:
        debug_dir = Path(debug_root) / receipt_id
        debug_dir.mkdir(parents=True, exist_ok=True)

    # ── Agent 1: metadata ──────────────────────────────────────────────────
    _progress.emit(f" {_ts()} → Agent 1: metadata")
    raw1 = call_llm(
        prompt=build_agent1_prompt(raw_text),
        cfg=agent_cfg,
        agent_name="agent1",
        expected_keys=["receipt_number", "receipt_date", "category"],
        debug_dir=debug_dir,
    )
    meta = _validate_agent1(raw1)

    # ── Agent 2: counterparty ──────────────────────────────────────────────
    _progress.emit(f" {_ts()} → Agent 2: counterparty")
    raw2 = call_llm(
        prompt=build_agent2_prompt(raw_text, receipt_type, taxpayer_info),
        cfg=agent_cfg,
        agent_name="agent2",
        expected_keys=[
            "name",
            "vat_id",
            "tax_number",
            "street_and_number",
            "address_supplement",
            "postcode",
            "city",
            "state",
            "country",
        ],
        debug_dir=debug_dir,
    )
    counterparty = _validate_agent2(raw2)
    counterparty = _strip_taxpayer_fields(counterparty, taxpayer_info)

    # ── Agent 3: amounts ───────────────────────────────────────────────────
    _progress.emit(f" {_ts()} → Agent 3: amounts")
    raw3 = call_llm(
        prompt=build_agent3_prompt(raw_text),
        cfg=agent_cfg,
        agent_name="agent3",
        expected_keys=["total_amount", "vat_percentage", "vat_amount", "currency"],
        debug_dir=debug_dir,
    )
    amounts = _validate_agent3(raw3)

    # ── Agent 4: line items ────────────────────────────────────────────────
    _progress.emit(f" {_ts()} → Agent 4: line items")
    raw4 = call_llm(
        prompt=build_agent4_prompt(raw_text),
        cfg=agent_cfg,
        agent_name="agent4",
        expected_keys=["items"],
        debug_dir=debug_dir,
    )
    items = _validate_agent4(raw4)

    # ── Debug: save final merge ────────────────────────────────────────────
    if debug_dir is not None:
        final_debug = {
            # The date is stringified explicitly; default=str below covers
            # any other value json cannot serialise natively.
            "meta": {**meta, "receipt_date": str(meta.get("receipt_date", ""))},
            "counterparty": counterparty,
            "amounts": amounts,
            "items": items,
        }
        (debug_dir / "final.json").write_text(
            json.dumps(final_debug, indent=2, ensure_ascii=False, default=str),
            encoding="utf-8",
        )

    return _build_receipt_data(meta, counterparty, amounts, items, raw_text, receipt_type)