Skip to content

Python API Reference

scan

doc_firewall.scan(file_path, config=None)

Source code in src/doc_firewall/scanner.py
def scan(file_path: str, config: Optional[ScanConfig] = None) -> ScanReport:
    """Module-level convenience wrapper: scan one file with a throwaway Scanner.

    Args:
        file_path: Path of the document to scan.
        config: Optional scan settings; Scanner falls back to defaults when None.

    Returns:
        The ScanReport produced by Scanner.scan.
    """
    scanner = Scanner(config=config)
    return scanner.scan(file_path)

Scanner

doc_firewall.Scanner

Source code in src/doc_firewall/scanner.py
class Scanner:
    """Two-stage document threat scanner.

    Stage 1 ("fast scan") runs cheap per-format byte-level checks.
    Stage 2 ("deep scan") — gated on the fast-scan risk score and the
    detected file type — parses the document and runs format checks,
    the detector suite, and an optional antivirus engine. All blocking
    work is pushed onto an internal thread pool so that ``scan_async``
    never blocks the calling event loop.
    """

    def __init__(self, config: Optional[ScanConfig] = None):
        """Create a scanner.

        Args:
            config: Scan settings; a default ``ScanConfig`` is built when None.
        """
        self.config = config or ScanConfig()
        self.risk_model = RiskModel(self.config)
        # NOTE(review): this pool is never shut down explicitly; it is only
        # reclaimed when the Scanner is garbage-collected / at interpreter
        # exit. Confirm that is acceptable for long-lived processes.
        self._executor = ThreadPoolExecutor(
            max_workers=getattr(self.config, "max_workers", 4)
        )
        # Initialize detectors (run in this order during the deep-scan stage)
        self.detectors = [
            EmbeddedPayloadDetector(),
            PdfDoSDetector(),  # Deep scan for DoS
            MetadataInjectionDetector(),
            ATSManipulationDetector(),
            PromptInjectionDetector(),
            RankingManipulationDetector(),
            YaraDetector(),
            TextObfuscationDetector(),
            HiddenTextDetector(),
            AdvancedPromptInjectionDetector(),
            AdvancedATSNLPDetector(),
            CredentialLeakageDetector(),
        ]

    async def scan_async(self, file_path: str) -> ScanReport:
        """Scan ``file_path`` asynchronously and return a populated ScanReport.

        Pipeline: pre-flight validation (regular file, size, hash, type
        detection) -> fast scan -> gating -> optional deep scan (parse,
        format checks, detectors, optional antivirus) -> final risk score
        and verdict.

        Raises:
            FileNotFoundError: If the path is not a regular file.
            ValueError: If the file exceeds the hashing hard limit, or is a
                symbolic link whose target does not exist.
        """
        file_path = os.path.abspath(file_path)

        # Security: Validate path resolves to a regular file
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Not a regular file: {file_path}")
        real_path = os.path.realpath(file_path)
        if real_path != file_path and not os.path.isfile(real_path):
            raise ValueError("Symbolic link target does not exist")

        # Basic File info
        try:
            size_bytes = os.path.getsize(file_path)
            # Guard against OOM: reject excessively large files before hashing.
            # Hard limit is 2x the configured per-scan size limit, in bytes.
            hard_limit = self.config.limits.max_mb * 1024 * 1024 * 2
            if size_bytes > hard_limit:
                raise ValueError(
                    f"File size ({size_bytes} bytes) exceeds hashing limit"
                )

            sha = sha256_file(file_path)

            # Determine file type by extension, then verify with magic bytes
            ftype = guess_file_type(file_path)
            magic_type = _detect_file_type_by_magic(file_path)
            if ftype != "unknown" and magic_type != "unknown" and ftype != magic_type:
                logger.warning(
                    "Extension/magic-byte mismatch",
                    extension_type=ftype,
                    magic_type=magic_type,
                )
                ftype = magic_type  # Trust magic bytes over extension
            elif ftype == "unknown" and magic_type != "unknown":
                ftype = magic_type

        except Exception as e:
            logger.error("Pre-flight check failed", file=file_path, error=str(e))
            raise

        # Bind file identity into the structured logger for the rest of the scan.
        log_ctx = logger.bind(file_path=file_path, sha256=sha, file_type=ftype)
        log_ctx.info("Starting scan")

        report = ScanReport(
            file_path=file_path, file_type=ftype, sha256=sha, size_bytes=size_bytes
        )

        # --- STAGE 1: FAST SCAN ---
        # Oversized files are flagged as a DoS risk and never deep-scanned.
        size_mb = size_bytes / (1024 * 1024)
        if size_mb > self.config.limits.max_mb:
            log_ctx.warning("File size exceeded", size_mb=size_mb)
            report.add(
                Finding(
                    threat_id=ThreatID.T6_DOS,
                    severity=Severity.HIGH,
                    title="File exceeds size limit",
                    explain=(
                        f"File is {size_mb:.2f} MB, "
                        f"limit is {self.config.limits.max_mb} MB."
                    ),
                    evidence={
                        "size_mb": size_mb,
                        "limit_mb": self.config.limits.max_mb,
                    },
                    module="preflight",
                )
            )
            report.risk_score = self.risk_model.calculate_risk(report.findings)
            report.verdict = self.risk_model.get_verdict(report.risk_score)
            return report  # Early exit

        fast_findings = []
        loop = asyncio.get_running_loop()

        with Timer() as t:

            # Runs on the thread pool: fast scans are blocking byte-level checks.
            def _run_fast_scan():
                findings = []
                # 1. Embedded Payload Fast Scan
                if self.config.enable_embedded_content_checks:
                    findings.extend(
                        EmbeddedPayloadDetector.fast_scan(file_path, self.config)
                    )

                # 2. Existing Fast Scans
                # NOTE(review): "pdf" uses a substring match ("pdf" in ftype)
                # while the other formats use equality — confirm intentional.
                if "pdf" in ftype and self.config.enable_pdf:
                    findings.extend(fast_scan_pdf(file_path, self.config))
                elif ftype == "docx" and self.config.enable_docx:
                    findings.extend(fast_scan_docx(file_path, self.config))
                elif ftype == "pptx" and self.config.enable_pptx:
                    findings.extend(fast_scan_pptx(file_path, self.config))
                elif ftype == "xlsx" and self.config.enable_xlsx:
                    findings.extend(fast_scan_xlsx(file_path, self.config))

                # 3. New DoS Fast Checks
                if "pdf" in ftype and self.config.enable_pdf:
                    findings.extend(PdfDoSDetector.fast_scan(file_path, self.config))

                return findings

            try:
                fast_findings = await loop.run_in_executor(
                    self._executor, _run_fast_scan
                )
            except Exception as e:
                # Fast-scan failure is logged but non-fatal: the scan proceeds
                # with whatever findings were produced (possibly none).
                log_ctx.error("Fast scan error", error=str(e))

        report.timings_ms["fast_scan"] = t.duration_ms
        report.findings.extend(fast_findings)

        # Gating Logic
        fast_score = self.risk_model.calculate_risk(report.findings)

        # If Critical -> Stop (the verdict is already decided; skip deep scan)
        if any(f.severity == Severity.CRITICAL for f in fast_findings):
            log_ctx.info("Critical fast finding, aborting deep scan")
            report.risk_score = fast_score
            report.verdict = self.risk_model.get_verdict(report.risk_score)
            return report

        # Determine Deep Scan: triggered by score, unknown type within size
        # limits, or any enabled supported format.
        should_deep_scan = False
        if fast_score >= self.config.thresholds.deep_scan_trigger:
            should_deep_scan = True
        elif ftype == "unknown" and size_mb < self.config.limits.max_mb:
            should_deep_scan = True
        elif (
            (ftype == "pdf" and self.config.enable_pdf)
            or (ftype == "docx" and self.config.enable_docx)
            or (ftype == "pptx" and self.config.enable_pptx)
            or (ftype == "xlsx" and self.config.enable_xlsx)
        ):
            should_deep_scan = True

        if not should_deep_scan:
            log_ctx.info("Skipping deep scan (score below threshold)", score=fast_score)
            report.risk_score = fast_score
            report.verdict = self.risk_model.get_verdict(report.risk_score)
            return report

        # --- STAGE 2: DEEP SCAN ---
        parsed_doc: Optional[ParsedDocument] = None

        # 2a. Parsing (bounded by limits.parse_timeout_ms)
        with Timer() as t:
            try:

                def _parse_task():
                    if ftype == "pdf" and self.config.enable_pdf:
                        return parse_pdf(file_path, self.config)
                    elif ftype == "docx" and self.config.enable_docx:
                        return parse_docx(file_path, self.config)
                    elif ftype == "pptx" and self.config.enable_pptx:
                        return parse_pptx(file_path, self.config)
                    elif ftype == "xlsx" and self.config.enable_xlsx:
                        return parse_xlsx(file_path, self.config)
                    # Unsupported/disabled type: empty document so detectors
                    # still run against metadata-free content.
                    return ParsedDocument(
                        file_path=file_path, file_type=ftype, text="", metadata={}
                    )

                parsed_doc = await asyncio.wait_for(
                    loop.run_in_executor(self._executor, _parse_task),
                    timeout=self.config.limits.parse_timeout_ms / 1000.0,
                )
            except asyncio.TimeoutError:
                log_ctx.error("Parsing timed out")
                report.add(
                    Finding(
                        threat_id=ThreatID.T6_DOS,
                        severity=Severity.HIGH,
                        title="Parsing timed out",
                        explain="Document parsing exceeded time limit.",
                        module="stage.parse",
                    )
                )
            except Exception as e:
                log_ctx.error("Parsing failed", error=str(e))
                report.add(
                    Finding(
                        threat_id=ThreatID.T6_DOS,
                        severity=Severity.MEDIUM,
                        title="Parsing failed",
                        explain=f"Document parsing error: {type(e).__name__}",
                        module="stage.parse",
                    )
                )
        report.timings_ms["parse"] = t.duration_ms

        # Stages 2b-2d only run if parsing produced a document.
        if parsed_doc:
            # 2b. Format Checks (Active Content / Obfuscation)
            with Timer() as t:
                try:

                    def _format_checks_task():
                        fs = []
                        if self.config.enable_active_content_checks:
                            if parsed_doc.file_type == "pdf":
                                fs.extend(
                                    detect_pdf_active_content(parsed_doc, self.config)
                                )
                            elif parsed_doc.file_type == "docx":
                                fs.extend(
                                    detect_docx_external_refs(parsed_doc, self.config)
                                )
                                fs.extend(
                                    detect_docx_ole_objects(parsed_doc, self.config)
                                )
                                fs.extend(detect_docx_macros(parsed_doc, self.config))
                            elif parsed_doc.file_type == "pptx":
                                fs.extend(
                                    detect_pptx_external_refs(parsed_doc, self.config)
                                )
                                fs.extend(detect_pptx_macros(parsed_doc, self.config))
                            elif parsed_doc.file_type == "xlsx":
                                fs.extend(
                                    detect_xlsx_external_refs(parsed_doc, self.config)
                                )
                                fs.extend(detect_xlsx_macros(parsed_doc, self.config))

                        if self.config.enable_obfuscation_checks:
                            if parsed_doc.file_type == "pdf":
                                fs.extend(
                                    detect_pdf_obfuscation(parsed_doc, self.config)
                                )
                            # Obfuscation logic for docx/pptx/xlsx handled in fast scan
                        return fs

                    format_findings = await asyncio.wait_for(
                        loop.run_in_executor(self._executor, _format_checks_task),
                        timeout=self.config.limits.format_checks_timeout_ms / 1000.0,
                    )
                    report.findings.extend(format_findings)
                except asyncio.TimeoutError:
                    report.add(
                        Finding(
                            threat_id=ThreatID.T6_DOS,
                            severity=Severity.MEDIUM,
                            title="Format checks timed out",
                            explain="Static analysis checks exceeded time limit.",
                            module="stage.format_checks",
                        )
                    )
                except Exception as e:
                    # Best-effort: a failed format-check stage does not abort the scan.
                    log_ctx.error("Format checks failed", error=str(e))
            report.timings_ms["format_checks"] = t.duration_ms

            # 2c. Detectors (full suite, bounded by limits.detectors_timeout_ms)
            with Timer() as t:
                try:

                    def _detectors_task():
                        out = []
                        for det in self.detectors:
                            out.extend(det.run(parsed_doc, self.config))
                        return out

                    det_findings = await asyncio.wait_for(
                        loop.run_in_executor(self._executor, _detectors_task),
                        timeout=self.config.limits.detectors_timeout_ms / 1000.0,
                    )
                    report.findings.extend(det_findings)
                except asyncio.TimeoutError:
                    report.add(
                        Finding(
                            threat_id=ThreatID.T6_DOS,
                            severity=Severity.MEDIUM,
                            title="Detectors timed out",
                            explain="Detection models exceeded time limit.",
                            module="stage.detectors",
                        )
                    )
                except Exception as e:
                    log_ctx.error("Detectors failed", error=str(e))
            report.timings_ms["detectors"] = t.duration_ms

            # 2d. Antivirus (Optional; only when an engine is configured)
            if self.config.antivirus_engine is not None:
                with Timer() as t:
                    try:

                        def _av_task():
                            return self.config.antivirus_engine.scan_file(file_path)

                        av_res = await asyncio.wait_for(
                            loop.run_in_executor(self._executor, _av_task),
                            timeout=self.config.limits.antivirus_timeout_ms / 1000.0,
                        )

                        # presumably av_res is a dict-like result; "infected"
                        # truthiness triggers a CRITICAL finding — confirm
                        # against the engine's contract.
                        if av_res.get("infected"):
                            report.add(
                                Finding(
                                    threat_id=ThreatID.T1_MALWARE,
                                    severity=Severity.CRITICAL,
                                    title="Antivirus detection",
                                    explain=(
                                        "Antivirus engine reported the "
                                        "file as infected."
                                    ),
                                    evidence=av_res,
                                    module="integrations.antivirus",
                                )
                            )
                    except asyncio.TimeoutError:
                        # AV timeout is only logged, not recorded as a finding.
                        log_ctx.warning("AV scan timed out")
                    except Exception as e:
                        log_ctx.error("Antivirus failed", error=str(e))
                        report.add(
                            Finding(
                                threat_id=ThreatID.T6_DOS,
                                severity=Severity.LOW,
                                title="AV check failed",
                                explain=(
                                    f"Antivirus integration error: {type(e).__name__}"
                                ),
                                module="stage.antivirus",
                            )
                        )
                report.timings_ms["antivirus"] = t.duration_ms

            # Populate content preview (text truncated to 1000 chars)
            report.content = {
                "text": (parsed_doc.text[:1000] + "...")
                if len(parsed_doc.text) > 1000
                else parsed_doc.text,
                "metadata": parsed_doc.metadata,
            }

        # Finalize: score all accumulated findings and derive the verdict.
        report.risk_score = self.risk_model.calculate_risk(report.findings)
        report.verdict = self.risk_model.get_verdict(report.risk_score)
        log_ctx.info(
            "Scan complete", verdict=report.verdict.value, score=report.risk_score
        )
        return report

    def scan(self, file_path: str) -> ScanReport:
        """Synchronous wrapper (blocking). Uses asyncio.run() for safety."""
        # Detect whether this thread already hosts a running event loop.
        try:
            asyncio.get_running_loop()
            is_running = True
        except RuntimeError:
            is_running = False

        if is_running:
            # Already inside an async context — run in a separate thread
            # to avoid reentrancy bugs from nest_asyncio
            from concurrent.futures import ThreadPoolExecutor as _TPE

            with _TPE(max_workers=1) as pool:
                future = pool.submit(asyncio.run, self.scan_async(file_path))
                return future.result()
        else:
            return asyncio.run(self.scan_async(file_path))

scan(file_path)

Synchronous wrapper (blocking). Uses asyncio.run() for safety.

Source code in src/doc_firewall/scanner.py
def scan(self, file_path: str) -> ScanReport:
    """Blocking entry point: drive scan_async to completion and return its report.

    Safe to call both from plain synchronous code and from code that is
    already running inside an event loop — in the latter case the scan is
    executed on a dedicated worker thread with its own fresh loop.
    """
    coro = self.scan_async(file_path)

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: run the coroutine directly.
        return asyncio.run(coro)

    # Already inside an async context — asyncio.run() would raise here, and
    # nest_asyncio-style reentrancy is bug-prone. Hand the coroutine to a
    # one-shot thread that owns its own event loop instead.
    from concurrent.futures import ThreadPoolExecutor as _TPE

    with _TPE(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()

ScanConfig

doc_firewall.ScanConfig

Bases: BaseSettings

Source code in src/doc_firewall/config.py
class ScanConfig(BaseSettings):
    """Runtime configuration for doc_firewall scans.

    Values may be overridden via environment variables prefixed with
    ``DOC_FIREWALL_`` (nested fields use ``__`` as delimiter) or loaded
    from YAML via :meth:`from_yaml`. The ``profile`` field applies a
    strict/lenient preset on top of the class defaults (see
    :meth:`apply_profile`).
    """

    # Per-format parser/scanner toggles.
    enable_pdf: bool = True
    enable_docx: bool = True
    enable_pptx: bool = True
    enable_xlsx: bool = True
    # Preset name: "strict", "lenient", or "balanced" (anything else = defaults).
    profile: str = "balanced"

    # Feature toggles; trailing T-codes refer to threat IDs.
    enable_antivirus: bool = False
    enable_active_content_checks: bool = True  # T2
    enable_yara: bool = False
    enable_prompt_injection: bool = True
    enable_ranking_abuse: bool = True
    enable_hidden_text: bool = True
    enable_obfuscation_checks: bool = True
    enable_dos_checks: bool = True
    enable_embedded_content_checks: bool = True  # T7
    enable_metadata_checks: bool = True  # T8
    enable_ats_manipulation_checks: bool = True  # T9

    # Advanced Machine Learning / Heuristic Detectors (all off by default)
    enable_advanced_ahocorasick: bool = False
    enable_advanced_bert: bool = False
    enable_advanced_tfidf: bool = False
    enable_credential_entropy: bool = False
    bert_model_path: str = "ProtectAI/deberta-v3-base-prompt-injection-v2"
    custom_ahocorasick_yaml_path: Optional[str] = None

    # False Positive Reductions
    allow_hidden_watermarks: bool = True

    enable_pii_checks: bool = True
    enable_secrets_checks: bool = True

    # ATS keywords list (resume/job-posting vocabulary used by ATS checks)
    ats_keywords: list[str] = Field(
        default_factory=lambda: [
            "python",
            "java",
            "sql",
            "aws",
            "docker",
            "developer",
            "engineer",
            "candidate",
            "top",
            "skills",
            "experience",
            "senior",
            "cloud",
            "agile",
            "management",
            "years",
            "expert",
            "data",
            "software",
            "development",
        ]
    )

    # Byte-level needles for fast prompt-injection scanning of raw content.
    prompt_injection_keywords_bytes: list[bytes] = Field(
        default_factory=lambda: [
            b"ignore previous",
            b"system instruction",
            b"system prompt",
            b"reveal your",
            b"ignore the above",
            b"new instruction",
            b"hiring manager",
            b"return a score",
            b"you are now",
            b"ignore all previous",
            b"rank this candidate",
            b"ignore instructions",
            b"rank this resume",
            b"forget all previous",
        ]
    )
    # Category -> list of (regex, weight) pairs for scored injection matching.
    prompt_injection_patterns: dict[str, list[tuple[str, float]]] = Field(
        default_factory=lambda: {
            "override_rerank": [
                (
                    r"\bignore (all |any |the above )?(previous |prior )?(instructions|rules|directions|prompts)\b",
                    2.0,
                ),
                (
                    r"\bdisregard (all |any )?(previous |prior )?(instructions|rules)\b",
                    2.0,
                ),
                (r"\byou are (now|no longer) bound by\b", 2.0),
                (
                    r"\b(unbound|free) from (all |any )?(constraints|rules|restrictions)\b",
                    2.0,
                ),
                (r"\byou must (never|always) (respond|answer|say)\b", 1.5),
                (r"\bstart ignoring instructions\b", 2.0),
                (r"\brank (me|this candidate) (as|at|top|first)\b", 2.0),
                (r"\b(prioritize|favor) (this|my) (application|resume)\b", 1.5),
                (r"\bignore (your|the) criteria\b", 2.0),
                (r"\bgive (me|this application) a score of\b", 2.0),
                (r"\breturn a score\b", 2.0),
                (r"\bimportant: ignore previous\b", 2.0),
                (
                    r"\b(no need to|don't) (follow|adhere to) (the )?(standard|usual) (process|protocol)\b",
                    1.5,
                ),
                (r"\boverride (safety|security) (protocols|checks)\b", 2.0),
                (r"\bhiring manager\b", 1.0),
            ],
            "secrecy": [
                (
                    r"\bdo not (disclose|reveal|mention) (this|these|your|the)( instructions?| rules?| prompt)?\b",
                    1.0,
                ),
                (r"\b(system prompt|developer message|hidden instructions)\b", 1.5),
                (r"\bwhat are your (instructions|rules|directives)\b", 1.5),
                (r"\brepeat (all|the) (instructions|words) above\b", 1.5),
                (r"\btell me your initial prompt\b", 2.0),
                (r"\bprint your instructions\b", 1.5),
                (r"\bretrieve (system|original) prompt\b", 1.5),
                (r"\boutput only:?\s*HIRE\b", 2.0),
            ],
            "system_policy": [
                (r"\byou are (now|going to be) (DAN|Mongo Tom|Sydney|ChaosGPT)\b", 2.0),
                (r"\bachieve (mode|persona):", 1.5),
                (r"\bdeveloper mode (enabled|on)\b", 2.0),
                (r"\bswitch to (developer|unrestricted) mode\b", 2.0),
                (r"\[system\]", 1.0),
            ],
        }
    )

    # Size/time limits for the scan pipeline.
    limits: Limits = Field(default_factory=Limits)

    # Text Obfuscation config
    obfuscation_zw_threshold_ratio: float = Field(
        0.01, description="Ratio of zero-width to total chars"
    )
    obfuscation_bidi_threshold_ratio: float = Field(
        0.005, description="Ratio of bidi chars to total chars"
    )
    obfuscation_entropy_threshold: float = Field(
        5.5, description="Shannon entropy threshold for base64/encrypted chunks"
    )
    thresholds: Thresholds = Field(default_factory=Thresholds)
    antivirus: AntivirusSettings = Field(default_factory=AntivirusSettings)

    # Advanced
    enable_semantic_scans: bool = True
    yara_rules_path: Optional[str] = None
    antivirus_engine: Optional[Any] = None  # duck-typed engine with scan_file()
    context: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        env_prefix = "DOC_FIREWALL_"
        env_nested_delimiter = "__"
        # NOTE(review): "scope" is not a standard pydantic Config key —
        # presumably project-specific or ignored; confirm it has an effect.
        scope = "local"  # or 'global' but Settings is usually singleton

    @classmethod
    def from_yaml(cls, path: str) -> "ScanConfig":
        """Load configuration from a YAML file."""
        import yaml

        # NOTE(review): safe_load returns None for an empty document, which
        # makes cls(**data) raise TypeError — consider defaulting to {}.
        with open(path, "r") as f:
            data = yaml.safe_load(f)
        return cls(**data)

    @model_validator(mode="before")
    @classmethod
    def warn_disabled_critical_checks(cls, values: dict) -> dict:
        """Warn when critical security checks are disabled via env/config."""
        import logging

        _log = logging.getLogger("doc_firewall.config")
        _critical = [
            "enable_pdf",
            "enable_docx",
            "enable_pptx",
            "enable_xlsx",
            "enable_active_content_checks",
            "enable_dos_checks",
            "enable_embedded_content_checks",
        ]
        # Only warn on an explicit False; missing keys (defaults) stay silent.
        if isinstance(values, dict):
            for key in _critical:
                if values.get(key) is False:
                    _log.warning(
                        "Critical security check '%s' is DISABLED. "
                        "Ensure this is intentional.",
                        key,
                    )
        return values

    @model_validator(mode="after")
    def apply_profile(self) -> "ScanConfig":
        """Overwrite thresholds/limits with the preset named by ``profile``."""
        # Logic to override limits/thresholds based on profile name
        # Note: In Pydantic model_validator(after), self is the Model instance.

        if self.profile == "strict":
            self.thresholds.deep_scan_trigger = 0.05
            self.thresholds.flag = 0.15
            self.thresholds.block = 0.50
            self.limits.max_docx_parts = 1000
            self.limits.max_mb = 10
        elif self.profile == "lenient":
            self.thresholds.deep_scan_trigger = 0.40
            self.thresholds.flag = 0.35
            self.thresholds.block = 0.80
            self.limits.max_docx_parts = 3000
            self.limits.max_mb = 25
        else:
            # balanced (default)
            # If manually set via env, we shouldn't overwrite?
            # But profile acts as a preset.
            # Let's assume profile wins if set explicitly to strict/lenient.
            # If balanced, we keep defaults defined in the Class.
            pass
        return self

from_yaml(path) classmethod

Load configuration from a YAML file.

Source code in src/doc_firewall/config.py
@classmethod
def from_yaml(cls, path: str) -> "ScanConfig":
    """Load configuration from a YAML file.

    Args:
        path: Path to a YAML file whose top-level mapping supplies
            keyword arguments for the ScanConfig constructor.

    Returns:
        A ScanConfig built from the file's contents (all defaults if
        the file is empty).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    import yaml

    with open(path, "r") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty/comment-only document; treat
    # that as "use all defaults" instead of crashing on cls(**None).
    if data is None:
        data = {}
    return cls(**data)

warn_disabled_critical_checks(values) classmethod

Warn when critical security checks are disabled via env/config.

Source code in src/doc_firewall/config.py
@model_validator(mode="before")
@classmethod
def warn_disabled_critical_checks(cls, values: dict) -> dict:
    """Warn when critical security checks are disabled via env/config.

    Emits one warning per flag that is explicitly set to False; keys
    that are merely absent (i.e. left at their defaults) stay silent.
    The input is returned unchanged.
    """
    import logging

    log = logging.getLogger("doc_firewall.config")
    critical_flags = (
        "enable_pdf",
        "enable_docx",
        "enable_pptx",
        "enable_xlsx",
        "enable_active_content_checks",
        "enable_dos_checks",
        "enable_embedded_content_checks",
    )
    if isinstance(values, dict):
        # Collect the explicitly-disabled flags, then warn for each one.
        disabled = [flag for flag in critical_flags if values.get(flag) is False]
        for key in disabled:
            log.warning(
                "Critical security check '%s' is DISABLED. "
                "Ensure this is intentional.",
                key,
            )
    return values

ScanReport

doc_firewall.report.ScanReport dataclass

Finding

doc_firewall.report.Finding dataclass