feat: auth
This commit is contained in:
50
main.py
50
main.py
@@ -1,16 +1,29 @@
|
||||
import os
|
||||
from secrets import compare_digest
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import Depends, FastAPI, HTTPException, Security
|
||||
from openai import APITimeoutError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
|
||||
from schemas import RoutePlanRequest, RoutePlanResult
|
||||
from agent import ConfigurationError, GuardrailError, run_route_plan
|
||||
from schemas import LoadPlanRequest, LoadPlanResult, RoutePlanRequest, RoutePlanResult
|
||||
from agent import (
|
||||
ConfigurationError,
|
||||
GuardrailError,
|
||||
LoadPlanConfigurationError,
|
||||
LoadPlanGuardrailError,
|
||||
LoadPlanToolConfigurationError,
|
||||
LoadPlanToolRequestError,
|
||||
run_load_plan,
|
||||
run_route_plan,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
|
||||
def _csv_env(name: str, default: str) -> list[str]:
|
||||
raw_value = os.getenv(name, default)
|
||||
@@ -21,6 +34,19 @@ def _bool_env(name: str, default: bool) -> bool:
|
||||
raw_value = os.getenv(name, "true" if default else "false").strip().lower()
|
||||
return raw_value in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise HTTPException(status_code=503, detail=f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def require_authorization(api_key: str | None = Security(api_key_header)) -> None:
|
||||
expected_key = _required_env("AGENT_HTTP_AUTH_KEY")
|
||||
if api_key is None or not compare_digest(api_key, expected_key):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
app = FastAPI(title="Geo Route Agent", version="0.1.0")
|
||||
|
||||
app.add_middleware(
|
||||
@@ -36,7 +62,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
@app.post("/route/plan", response_model=RoutePlanResult)
|
||||
@app.post("/route/plan", response_model=RoutePlanResult, dependencies=[Depends(require_authorization)])
|
||||
async def route_plan(request: RoutePlanRequest) -> RoutePlanResult:
|
||||
try:
|
||||
return await run_route_plan(request)
|
||||
@@ -50,6 +76,22 @@ async def route_plan(request: RoutePlanRequest) -> RoutePlanResult:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.post("/load/plan", response_model=LoadPlanResult, dependencies=[Depends(require_authorization)])
|
||||
async def load_plan(request: LoadPlanRequest) -> LoadPlanResult:
|
||||
try:
|
||||
return await run_load_plan(request)
|
||||
except (LoadPlanConfigurationError, LoadPlanToolConfigurationError) as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except LoadPlanGuardrailError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except LoadPlanToolRequestError as exc:
|
||||
raise HTTPException(status_code=502, detail=str(exc)) from exc
|
||||
except (httpx.TimeoutException, APITimeoutError, TimeoutError) as exc:
|
||||
raise HTTPException(status_code=504, detail=f"Upstream request timed out: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
Reference in New Issue
Block a user