#!/usr/bin/env python3
"""
Validate an S3 access point policy against the defence-in-depth rules.

This checker intentionally goes beyond the baseline pattern in
https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-points-policies.html by
requiring explicit deny statements so that the access point stays read-only and
prefix-scoped even if the IAM role later gains broader permissions.

Rules enforced (per access point policy):
- Only the expected IAM role is granted access through the access point.
- Object resources for that role are scoped to the expected prefix.
- Any s3:ListBucket allows for that role are limited to the same prefix.
- Explicit deny statements provide the backstop (principal lockdown, write guard,
  prefix guard for list/get).
"""
from __future__ import annotations

import argparse
import json
import sys
from dataclasses import dataclass, field
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Set, Tuple


class ColourPrinter:
    """Wrap text in ANSI escape sequences when colour output is switched on.

    ``mode`` selects the behaviour: ``"always"`` forces colour, ``"never"``
    disables it, and ``"auto"`` enables it only when ``sys.stdout`` is a TTY.
    Any other mode raises ``ValueError``.
    """

    CODES = {
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "reset": "\033[0m",
        "bold": "\033[1m",
    }

    def __init__(self, mode: str) -> None:
        forced = {"always": True, "never": False}
        if mode == "auto":
            self.enabled = sys.stdout.isatty()
        elif mode in forced:
            self.enabled = forced[mode]
        else:
            raise ValueError(f"unsupported colour mode: {mode}")

    def _wrap(self, text: str, code: str) -> str:
        """Surround *text* with the escape *code* and the reset sequence."""
        return f"{code}{text}{self.CODES['reset']}"

    def colour(self, text: str, colour: str) -> str:
        """Colour *text*; unknown colour names (or colour disabled) return it unchanged."""
        code = self.CODES.get(colour) if self.enabled else None
        return text if code is None else self._wrap(text, code)

    def bold(self, text: str) -> str:
        """Embolden *text* when colour output is enabled."""
        return self._wrap(text, self.CODES["bold"]) if self.enabled else text


def load_policy(path: Path) -> dict:
    """Read and parse the policy JSON document at *path*.

    The file is decoded as UTF-8; a leading byte-order mark is tolerated
    (``utf-8-sig``), since editors on Windows often add one.

    Returns:
        The parsed top-level JSON object.

    Raises:
        SystemExit: with an ``error: ...`` message when the file is missing,
            unreadable, not valid UTF-8, not valid JSON, or when its top level
            is not a JSON object (callers rely on ``dict.get``).
    """
    try:
        with path.open("r", encoding="utf-8-sig") as f:
            policy = json.load(f)
    except FileNotFoundError:
        raise SystemExit(f"error: policy file not found: {path}") from None
    except json.JSONDecodeError as exc:
        raise SystemExit(f"error: invalid JSON in {path}: {exc}") from None
    except UnicodeDecodeError as exc:
        raise SystemExit(f"error: policy file is not valid UTF-8: {path}: {exc}") from None
    except OSError as exc:
        # PermissionError, IsADirectoryError, ... -> clean message, not a traceback.
        raise SystemExit(f"error: cannot read policy file {path}: {exc.strerror or exc}") from None
    if not isinstance(policy, dict):
        raise SystemExit(
            f"error: policy in {path} must be a JSON object, got {type(policy).__name__}"
        )
    return policy


def to_list(value) -> List:
    """Coerce an IAM policy field that may be absent, scalar or a list into a list.

    ``None`` becomes ``[]``; an existing list is returned as-is (not copied);
    any other value is wrapped in a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]


def normalise_principal(principal) -> List[str]:
    """Flatten an IAM ``Principal`` element into a flat list of identifiers.

    ``"*"`` stays ``["*"]``. A mapping such as ``{"AWS": [...], "Service": ...}``
    contributes the values under every key (the key names are discarded).
    Anything else is passed through ``to_list``.
    """
    if principal == "*":
        return ["*"]
    if not isinstance(principal, dict):
        return to_list(principal)
    return [entry for values in principal.values() for entry in to_list(values)]


def normalise_actions(actions) -> List[str]:
    """Return an IAM ``Action`` element as a list of action names.

    A bare string (``"*"`` included) becomes a one-element list, a list is
    returned unchanged, and a missing element yields ``[]``.
    """
    return to_list(actions)


def normalise_resources(resources) -> List[str]:
    """Return an IAM ``Resource`` / ``NotResource`` element as a list.

    Thin wrapper over ``to_list``: ``None`` -> ``[]``, a single string ->
    ``[string]``, a list is returned as-is.
    """
    return to_list(resources)


def principal_matches(principals: Sequence[str], target: str) -> bool:
    """Return True when *target* appears verbatim among *principals*.

    No wildcard handling is done: a ``"*"`` entry does not match a specific ARN.
    """
    return any(candidate == target for candidate in principals)


def infer_prefix_from_objects(resources: Iterable[str]) -> Optional[str]:
    """Infer the single object-key prefix named by access point object ARNs.

    Only entries containing ``/object/`` are considered. For each one, the
    part after ``/object/`` loses a trailing ``*`` and any trailing ``/``.
    As a result, ``.../object/logs/*``, ``.../object/logs/`` and ``.../object/logs``
    all count as the same prefix ``logs``. This matches how ``evaluate_policy``
    normalises prefixes.

    Returns:
        The prefix (without a trailing slash) when exactly one distinct value
        is found, ``""`` when that value is the bare ``.../object/*``, and
        ``None`` when there are zero or several distinct prefixes.
    """
    prefixes = set()
    for res in resources:
        _, marker, suffix = res.partition("/object/")
        if not marker:
            continue  # not an object ARN (e.g. the access point ARN itself)
        if suffix.endswith("*"):
            suffix = suffix[:-1]
        # Normalise before de-duplicating so "logs" and "logs/" do not count
        # as two different prefixes.
        prefixes.add(suffix.rstrip("/"))
    if len(prefixes) == 1:
        return prefixes.pop()
    return None


@dataclass
class Finding:
    """One failed compliance rule; ``main`` prints it as a bullet with optional details."""

    # Human-readable description of the rule that failed.
    message: str
    # Offending policy fragment(s) or observed values; "" means nothing to show.
    actual: str = ""
    # Example compliant fragment or guidance text; "" means nothing to show.
    expected: str = ""


@dataclass
class ComplianceResult:
    """Aggregate verdict for one policy evaluation.

    ``compliant`` starts as whatever the caller passes and is set to ``False``
    by every recorded failure. ``findings`` keeps the failures in the order
    they were reported.
    """

    compliant: bool
    findings: List[Finding] = field(default_factory=list)

    def add_failure(self, message: str, actual: str = "", expected: str = "") -> None:
        """Record a failed rule and mark the whole result as non-compliant."""
        self.findings.append(Finding(message=message, actual=actual, expected=expected))
        self.compliant = False


# Condition operators that positively restrict s3:prefix on an Allow. The
# ...IfExists variants are deliberately excluded: they evaluate to true when
# the request carries no s3:prefix at all, i.e. an unrestricted listing.
_ALLOW_PREFIX_OPERATORS = ("StringLike", "StringEquals")
# Negated operators make a Deny fire for any s3:prefix outside the allowed one
# (IAM also treats a negated match on an absent key as true, so prefix-less
# listings are denied too).
_DENY_PREFIX_OPERATORS = (
    "StringNotLike",
    "StringNotEquals",
    "StringNotLikeIfExists",
    "StringNotEqualsIfExists",
)
# Operators accepted for the "everyone except the role" principal exclusion.
_PRINCIPAL_EXCLUSION_OPERATORS = (
    "StringNotEquals",
    "StringNotEqualsIgnoreCase",
    "StringNotLike",
    "ArnNotEquals",
    "ArnNotLike",
)
# Object write actions the allowed role must be explicitly denied.
_REQUIRED_WRITE_DENIES = frozenset(
    {
        "s3:AbortMultipartUpload",
        "s3:DeleteObject",
        "s3:DeleteObjectVersion",
        "s3:PutObject",
        "s3:PutObjectAcl",
        "s3:PutObjectTagging",
        "s3:PutObjectVersionAcl",
        "s3:PutObjectVersionTagging",
        "s3:RestoreObject",
    }
)


def _example_allow(role_arn: str) -> str:
    """Return the JSON snippet shown as the expected shape of the role's Allow."""
    return json.dumps(
        {
            "Effect": "Allow",
            "Principal": {"AWS": role_arn},
            "Action": ["s3:GetObject"],
            "Resource": "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT/object/PREFIX/*",
        },
        indent=4,
    )


def _dump_statements(statements: Iterable[dict]) -> str:
    """Render policy statements as indented JSON blocks, one after another."""
    return "\n".join(json.dumps(stmt, indent=4) for stmt in statements)


def _unique(statements: Iterable[dict]) -> List[dict]:
    """Drop repeated references to the same statement object, keeping order."""
    seen: Set[int] = set()
    unique: List[dict] = []
    for stmt in statements:
        if id(stmt) not in seen:
            seen.add(id(stmt))
            unique.append(stmt)
    return unique


def _resources(stmt: dict, key: str = "Resource") -> List[str]:
    """Return a statement's Resource (or NotResource) entries as strings."""
    return [str(res) for res in normalise_resources(stmt.get(key))]


def _actions(stmt: dict) -> List[str]:
    """Return a statement's Action entries as strings (NotAction is not read)."""
    return [str(action) for action in normalise_actions(stmt.get("Action"))]


def _actions_cover(actions: Sequence[str], target: str) -> bool:
    """Return True when any IAM action pattern in *actions* matches *target*.

    IAM action names are case-insensitive and may contain ``*``/``?``
    wildcards, so ``*``, ``s3:*`` and ``s3:Put*`` all cover ``s3:PutObject``.
    """
    wanted = target.lower()
    return any(fnmatchcase(wanted, pattern.lower()) for pattern in actions)


def _grants_listing(actions: Sequence[str]) -> bool:
    """Return True when *actions* cover s3:ListBucket or name a ListBucket* action."""
    return _actions_cover(actions, "s3:ListBucket") or any(
        action.lower().startswith("s3:listbucket") for action in actions
    )


def _prefix_matches(value: str, prefix: str) -> bool:
    """Return True when *value* (trailing ``*`` ignored) names exactly *prefix*.

    A directory-style prefix such as ``logs/`` must keep its trailing slash,
    because ``logs*`` would also match ``logs-private/...``. A prefix given
    without a trailing slash is compared slash-insensitively.
    """
    bare = value.rstrip("*")
    if prefix.endswith("/"):
        return bare == prefix
    return bare.rstrip("/") == prefix.rstrip("/")


def _condition_values(condition: object, operators: Sequence[str], key: str) -> List[str]:
    """Return the non-empty values bound to *key* under any of *operators*.

    Condition key names are compared case-insensitively, as IAM does; operator
    names must match exactly. Empty strings are skipped (as before), so
    console-style ``["", "prefix/*"]`` lists are judged on the real prefix.
    """
    if not isinstance(condition, dict):
        return []
    wanted = key.lower()
    values: List[str] = []
    for operator in operators:
        block = condition.get(operator)
        if not isinstance(block, dict):
            continue
        for cond_key, raw in block.items():
            if str(cond_key).lower() == wanted:
                values.extend(str(v) for v in to_list(raw) if v)
    return values


def _is_principal_lockdown(stmt: dict, role_arn: str) -> bool:
    """Return True for a Deny that blocks all S3 actions for everyone but *role_arn*.

    The statement must use Principal ``*``, cover ``s3:*`` and avoid
    NotAction/NotResource. Its only condition must be a single negated match
    of ``aws:PrincipalArn`` against exactly the role. Condition clauses are
    ANDed, so any extra clause would narrow the deny and re-open access to
    other principals.
    """
    if normalise_principal(stmt.get("Principal")) != ["*"]:
        return False
    if "NotAction" in stmt or "NotResource" in stmt:
        return False
    if not _actions_cover(_actions(stmt), "s3:*"):
        return False
    condition = stmt.get("Condition")
    if not isinstance(condition, dict) or len(condition) != 1:
        return False
    ((operator, block),) = condition.items()
    if (
        operator not in _PRINCIPAL_EXCLUSION_OPERATORS
        or not isinstance(block, dict)
        or len(block) != 1
    ):
        return False
    ((cond_key, values),) = block.items()
    return str(cond_key).lower() == "aws:principalarn" and set(
        map(str, to_list(values))
    ) == {role_arn}


def _is_list_prefix_guard(stmt: dict, prefix: str) -> bool:
    """Return True for a ListBucket Deny that fires whenever s3:prefix is outside *prefix*.

    Only negated operators qualify. A Deny with ``StringLike`` on the prefix
    would block the allowed prefix and permit everything else.
    """
    resources = _resources(stmt)
    # ListBucket is evaluated against the access point ARN, not object ARNs.
    if not resources or all("/object/" in res for res in resources):
        return False
    values = _condition_values(stmt.get("Condition"), _DENY_PREFIX_OPERATORS, "s3:prefix")
    return bool(values) and all(_prefix_matches(value, prefix) for value in values)


def _is_prefix_get_guard(stmt: dict, prefix: str) -> bool:
    """Return True for an unconditional Deny whose NotResource exempts only *prefix* objects.

    Only NotResource works here. A Deny on ``Resource .../object/<prefix>*``
    would block the prefix itself rather than everything outside it.
    """
    if stmt.get("Condition") or "Resource" in stmt:
        return False
    not_resources = _resources(stmt, "NotResource")
    suffixes = [res.split("/object/", 1)[1] for res in not_resources if "/object/" in res]
    if not suffixes:
        return False
    # A wildcard outside the object namespace ("*", ".../accesspoint/ap/*")
    # would exempt every object and neuter the deny.
    if any("*" in res for res in not_resources if "/object/" not in res):
        return False
    return all(_prefix_matches(suffix, prefix) for suffix in suffixes)


def _scoped_object_suffix(suffix: str, prefix: str) -> bool:
    """Return True for an object-ARN suffix of the form ``<prefix>*``."""
    return (
        suffix.endswith("*")
        and bool(suffix.rstrip("*").rstrip("/"))
        and _prefix_matches(suffix, prefix)
    )


def _check_allow_statements(
    statements: List[dict],
    role_arn: str,
    prefix_hint: Optional[str],
    result: ComplianceResult,
) -> Optional[str]:
    """Apply the Allow-side rules and return the prefix to enforce (hint or inferred)."""
    role_allows: List[dict] = []
    wildcard_principal: List[dict] = []
    foreign_principal: List[dict] = []
    for stmt in statements:
        if stmt.get("Effect") != "Allow":
            continue
        principal_set = set(map(str, normalise_principal(stmt.get("Principal"))))
        if "*" in principal_set:
            wildcard_principal.append(stmt)
        if role_arn in principal_set:
            role_allows.append(stmt)
            if principal_set != {role_arn}:
                foreign_principal.append(stmt)
        elif principal_set - {"*"}:
            foreign_principal.append(stmt)

    if wildcard_principal:
        result.add_failure(
            "Policy allows '*' principal. Restrict the access point policy to the expected role.",
            actual=_dump_statements(wildcard_principal),
            expected=_example_allow(role_arn),
        )
    if foreign_principal:
        result.add_failure(
            "Policy grants Allow permissions to principals other than the expected role.",
            actual=_dump_statements(foreign_principal),
            expected=_example_allow(role_arn),
        )
    if not role_allows:
        result.add_failure(
            "No Allow statement found for the specified role.",
            actual="<none>",
            expected=_example_allow(role_arn),
        )

    # NotAction/NotResource on an Allow grant everything except the listed
    # entries, so they can never be prefix-scoped.
    negated = [stmt for stmt in role_allows if "NotAction" in stmt or "NotResource" in stmt]
    if negated:
        result.add_failure(
            "Allow statements for the role must not use NotAction or NotResource.",
            actual=_dump_statements(negated),
            expected=_example_allow(role_arn),
        )

    object_resources: List[Tuple[str, dict]] = []
    broad: List[dict] = []
    for stmt in role_allows:
        resources = _resources(stmt)
        object_resources.extend((res, stmt) for res in resources if "/object/" in res)
        # "*" or ".../accesspoint/ap/*" match every object without naming a prefix.
        if any("*" in res and "/object/" not in res for res in resources):
            broad.append(stmt)
    if broad:
        result.add_failure(
            "Allow statements for the role use wildcard resources that are not scoped to an object prefix.",
            actual=_dump_statements(broad),
            expected=_example_allow(role_arn),
        )

    prefix = prefix_hint
    if prefix is None:
        inferred = {
            res.split("/object/", 1)[1].rstrip("*").rstrip("/") for res, _ in object_resources
        } - {""}
        if len(inferred) == 1:
            prefix = f"{inferred.pop()}/"
        elif inferred:
            result.add_failure(
                "Unable to determine a single object prefix from the policy.",
                actual="Detected prefixes: " + ", ".join(sorted(inferred)),
                expected="Ensure all object resources end with object/<prefix>* or pass --prefix explicitly.",
            )
        elif object_resources:
            result.add_failure(
                "Object resources are not scoped to a prefix.",
                actual="\n".join(res for res, _ in object_resources),
                expected="Use resources like .../object/s3accesslogs/* or supply --prefix.",
            )

    if not object_resources:
        result.add_failure(
            "No object-level resources found in Allow statements for the role.",
            actual="<none>",
            expected="Include at least one Allow with Resource .../object/<prefix>*.",
        )
    elif prefix:
        mismatched = _unique(
            stmt
            for res, stmt in object_resources
            if not _scoped_object_suffix(res.split("/object/", 1)[1], prefix)
        )
        if mismatched:
            result.add_failure(
                f"Object resources for the role must be limited to '{prefix}'.",
                actual=_dump_statements(mismatched),
                expected=json.dumps(
                    {
                        "Resource": [
                            f"arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT/object/{prefix}*"
                        ]
                    },
                    indent=4,
                ),
            )

        # Every ListBucket grant on the access point must carry the prefix
        # condition. A single guarded statement does not excuse an unguarded one.
        unguarded: List[dict] = []
        for stmt in role_allows:
            if not _grants_listing(_actions(stmt)):
                continue
            resources = _resources(stmt)
            if resources and all("/object/" in res for res in resources):
                continue  # ListBucket is evaluated against the access point ARN
            values = _condition_values(stmt.get("Condition"), _ALLOW_PREFIX_OPERATORS, "s3:prefix")
            if not values or not all(_prefix_matches(value, prefix) for value in values):
                unguarded.append(stmt)
        if unguarded:
            result.add_failure(
                f"'s3:ListBucket' allows are not limited to the '{prefix}' prefix.",
                actual=_dump_statements(unguarded),
                expected=json.dumps(
                    {
                        "Effect": "Allow",
                        "Principal": {"AWS": role_arn},
                        "Action": "s3:ListBucket",
                        "Resource": "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT",
                        "Condition": {"StringLike": {"s3:prefix": f"{prefix}*"}},
                    },
                    indent=4,
                ),
            )
    return prefix


def _check_deny_statements(
    statements: List[dict],
    role_arn: str,
    prefix: Optional[str],
    result: ComplianceResult,
) -> None:
    """Apply the explicit-deny (defence-in-depth) rules."""
    denies = [stmt for stmt in statements if stmt.get("Effect") == "Deny"]
    role_denies = [
        stmt
        for stmt in denies
        if principal_matches(normalise_principal(stmt.get("Principal")), role_arn)
    ]

    if not any(_is_principal_lockdown(stmt, role_arn) for stmt in denies):
        star_denies = [
            stmt for stmt in denies if normalise_principal(stmt.get("Principal")) == ["*"]
        ]
        result.add_failure(
            "Missing deny statement that blocks every principal except the expected role.",
            actual=_dump_statements(star_denies) if star_denies else "<none>",
            expected=json.dumps(
                {
                    "Effect": "Deny",
                    "Principal": "*",
                    "Action": "s3:*",
                    "Resource": [
                        "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT",
                        "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT/object/*",
                    ],
                    "Condition": {"StringNotEquals": {"aws:PrincipalArn": role_arn}},
                },
                indent=4,
            ),
        )

    # Only unconditional, Resource-based denies are a reliable write guard: a
    # Condition or NotResource narrows where the deny applies (e.g. a deny
    # "outside the prefix" still lets the role write inside it).
    # NOTE(review): the Resource scope of these denies is not inspected.
    blanket_patterns = sorted(
        {
            action
            for stmt in role_denies
            if not stmt.get("Condition") and "NotResource" not in stmt
            for action in _actions(stmt)
        }
    )
    uncovered = sorted(
        action for action in _REQUIRED_WRITE_DENIES if not _actions_cover(blanket_patterns, action)
    )
    if uncovered:
        result.add_failure(
            "Allowed role must be explicitly denied write-style S3 object actions.",
            actual=(
                "Unconditional denied actions observed: "
                + (", ".join(blanket_patterns) or "<none>")
                + "\nNot covered: "
                + ", ".join(uncovered)
            ),
            expected="At minimum deny: " + ", ".join(sorted(_REQUIRED_WRITE_DENIES)),
        )

    get_denies = [stmt for stmt in role_denies if _actions_cover(_actions(stmt), "s3:GetObject")]
    if prefix is None:
        # Fall back to the prefix exempted by GetObject NotResource denies.
        candidates: Set[str] = set()
        for stmt in get_denies:
            candidate = infer_prefix_from_objects(_resources(stmt, "NotResource"))
            if candidate and candidate.rstrip("/"):
                candidates.add(candidate.rstrip("/"))
        if len(candidates) == 1:
            prefix = f"{candidates.pop()}/"

    list_denies = [stmt for stmt in role_denies if _grants_listing(_actions(stmt))]
    if not list_denies:
        result.add_failure(
            "Missing deny statement that targets s3:ListBucket for the allowed role.",
            actual="<none>",
            expected=json.dumps(
                {
                    "Effect": "Deny",
                    "Principal": {"AWS": role_arn},
                    "Action": "s3:ListBucket",
                    "Resource": "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT",
                    "Condition": {"StringNotLike": {"s3:prefix": "PREFIX*"}},
                },
                indent=4,
            ),
        )
    elif prefix and not any(_is_list_prefix_guard(stmt, prefix) for stmt in list_denies):
        result.add_failure(
            f"Deny statements for s3:ListBucket must guard the '{prefix}' prefix.",
            actual=_dump_statements(list_denies) or "<none>",
            expected=json.dumps(
                {
                    "Effect": "Deny",
                    "Principal": {"AWS": role_arn},
                    "Action": "s3:ListBucket",
                    "Resource": "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT",
                    "Condition": {"StringNotLike": {"s3:prefix": f"{prefix}*"}},
                },
                indent=4,
            ),
        )

    if not get_denies:
        result.add_failure(
            "Missing deny statement that targets s3:GetObject for the allowed role.",
            actual="<none>",
            expected=json.dumps(
                {
                    "Effect": "Deny",
                    "Principal": {"AWS": role_arn},
                    "Action": "s3:GetObject",
                    "NotResource": [
                        "arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT/object/PREFIX*"
                    ],
                },
                indent=4,
            ),
        )
    elif prefix and not any(_is_prefix_get_guard(stmt, prefix) for stmt in get_denies):
        result.add_failure(
            f"Deny statements for s3:GetObject must exclude objects outside '{prefix}'.",
            actual=_dump_statements(get_denies) or "<none>",
            expected=json.dumps(
                {
                    "Effect": "Deny",
                    "Principal": {"AWS": role_arn},
                    "Action": "s3:GetObject",
                    "NotResource": [
                        f"arn:aws:s3:REGION:ACCOUNT:accesspoint/ACCESS_POINT/object/{prefix}*"
                    ],
                },
                indent=4,
            ),
        )


def evaluate_policy(policy: dict, role_arn: str, prefix_hint: Optional[str]) -> ComplianceResult:
    """Check an access point *policy* against the defence-in-depth rules for *role_arn*.

    Allow side:
      * only *role_arn* is granted anything, without NotAction/NotResource or
        unscoped wildcard resources;
      * every object ARN granted to the role is ``.../object/<prefix>*``;
      * every s3:ListBucket grant on the access point carries an
        ``s3:prefix`` condition naming only the prefix.

    Deny side:
      * a Principal ``*`` deny of ``s3:*`` for everyone except the role;
      * unconditional denies covering the object write actions;
      * a ListBucket deny with a negated ``s3:prefix`` condition;
      * an unconditional GetObject deny whose NotResource exempts only the prefix.

    Args:
        policy: Parsed policy document (a JSON object).
        role_arn: The only principal that may use the access point.
        prefix_hint: Expected key prefix. When ``None`` it is inferred from
            the role's object ARNs, falling back to the GetObject deny.

    Returns:
        A ``ComplianceResult`` with one ``Finding`` per failed rule.
    """
    result = ComplianceResult(compliant=True)
    entries = to_list(policy.get("Statement"))
    statements = [entry for entry in entries if isinstance(entry, dict)]
    malformed = [entry for entry in entries if not isinstance(entry, dict)]
    if malformed:
        result.add_failure(
            "Policy contains Statement entries that are not JSON objects.",
            actual="\n".join(json.dumps(entry) for entry in malformed),
            expected="Every Statement entry must be a JSON object.",
        )
    prefix = _check_allow_statements(statements, role_arn, prefix_hint, result)
    _check_deny_statements(statements, role_arn, prefix, result)
    return result


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Command-line entry point.

    Parses *argv* (``None`` means ``sys.argv[1:]``), evaluates the policy file
    and prints a report, colourised according to ``--color``. Returns 0 when
    the policy is compliant and 1 otherwise. Missing or invalid policy files
    make ``load_policy`` exit before any report is printed.
    """
    parser = argparse.ArgumentParser(
        description="Validate an S3 access point policy against the defence-in-depth policy."
    )
    parser.add_argument("policy", type=Path, help="Path to the policy JSON file.")
    parser.add_argument(
        "--role",
        required=True,
        help="Expected IAM role ARN allowed to use the access point.",
    )
    parser.add_argument(
        "--prefix",
        help="Expected key prefix (e.g. s3accesslogs/). If omitted the script will attempt to infer it.",
    )
    parser.add_argument(
        "--color",
        choices=["auto", "always", "never"],
        default="auto",
        help="Colourise output (default: auto)",
    )
    options = parser.parse_args(argv)

    outcome = evaluate_policy(load_policy(options.policy), options.role, options.prefix)
    printer = ColourPrinter(options.color)

    if outcome.compliant:
        print(printer.colour("Policy is compliant.", "green"))
        return 0

    print(printer.colour("Policy is NOT compliant.", "red"))
    for finding in outcome.findings:
        print(f"- {printer.bold(finding.message)}")
        sections = (
            ("actual:", finding.actual, "red"),
            ("expected:", finding.expected, "green"),
        )
        for label, body, tint in sections:
            if not body:
                continue
            print(f"  {printer.colour(label, tint)}")
            for line in body.splitlines():
                print(printer.colour(f"    {line}", tint))
    return 1


if __name__ == "__main__":
    # Use main()'s return value (0 = compliant, 1 = not compliant) as the exit status.
    raise SystemExit(main())
