/**
 * Per-port AI budget enforcement.
 *
 * Budgets are denominated in tokens (input + output) over a calendar-aligned
 * UTC window: the current day (from 00:00 UTC), the current ISO week (from
 * Monday 00:00 UTC), or the current month (from the 1st, 00:00 UTC).
 * Two thresholds:
 *  - softCapTokens: the call is allowed, but `checkBudget` reports
 *    `softCap: true` so callers can log a warning / surface a banner
 *  - hardCapTokens: refuse the call until the period rolls over
 *
 * Stored in `system_settings` under key `ai.budget` per port. Usage is
 * accumulated in `ai_usage_ledger` and rolled up by SQL.
 */
import { and, eq, gte, sql } from 'drizzle-orm';

import { db } from '@/lib/db';
import { aiUsageLedger } from '@/lib/db/schema/ai-usage';
import { systemSettings } from '@/lib/db/schema/system';
import { logger } from '@/lib/logger';

export type BudgetPeriod = 'day' | 'week' | 'month';

export interface AiBudget {
  /** When false, the budget is disabled — no caps enforced. */
  enabled: boolean;
  softCapTokens: number;
  hardCapTokens: number;
  period: BudgetPeriod;
}

const KEY = 'ai.budget';

const DEFAULT_BUDGET: AiBudget = {
  enabled: false,
  softCapTokens: 100_000,
  hardCapTokens: 500_000,
  period: 'month',
};

/** Runtime guard for period values coming from storage or untyped input. */
function isBudgetPeriod(value: unknown): value is BudgetPeriod {
  return value === 'day' || value === 'week' || value === 'month';
}

/**
 * Load the stored budget for a port.
 *
 * Each field falls back to DEFAULT_BUDGET when it is missing or has the wrong
 * type, so malformed stored data never throws. A missing row yields a copy of
 * DEFAULT_BUDGET (i.e. a disabled budget).
 */
async function readBudget(portId: string): Promise<AiBudget> {
  const [row] = await db
    .select()
    .from(systemSettings)
    .where(and(eq(systemSettings.key, KEY), eq(systemSettings.portId, portId)))
    .limit(1);
  if (!row) return { ...DEFAULT_BUDGET };

  // A null / non-object JSON value would make the property reads below throw.
  const v: Partial<AiBudget> =
    row.value !== null && typeof row.value === 'object'
      ? (row.value as Partial<AiBudget>)
      : {};

  return {
    enabled: v.enabled === true,
    softCapTokens:
      typeof v.softCapTokens === 'number' ? v.softCapTokens : DEFAULT_BUDGET.softCapTokens,
    hardCapTokens:
      typeof v.hardCapTokens === 'number' ? v.hardCapTokens : DEFAULT_BUDGET.hardCapTokens,
    period: isBudgetPeriod(v.period) ? v.period : DEFAULT_BUDGET.period,
  };
}

/** Public read of the port's effective budget (defaults applied). */
export async function getAiBudget(portId: string): Promise<AiBudget> {
  return readBudget(portId);
}

/**
 * Merge `input` over the port's current budget, validate, and persist it.
 *
 * @param portId - port whose budget is updated
 * @param input - fields to change; omitted fields keep their current value
 * @param userId - written to `updatedBy` on the settings row
 * @returns the budget exactly as persisted
 * @throws Error when `enabled` is not a boolean, a cap is non-finite or
 *   negative, the period is unknown, or softCapTokens > hardCapTokens
 */
export async function setAiBudget(
  portId: string,
  input: Partial<AiBudget>,
  userId: string,
): Promise<AiBudget> {
  const existing = await readBudget(portId);
  const next: AiBudget = {
    enabled: input.enabled ?? existing.enabled,
    softCapTokens: input.softCapTokens ?? existing.softCapTokens,
    hardCapTokens: input.hardCapTokens ?? existing.hardCapTokens,
    period: input.period ?? existing.period,
  };

  // `input` may originate from an untyped request body. Reject anything that
  // would not read back as a valid budget (readBudget silently replaces such
  // fields with defaults, which could quietly disable enforcement).
  if (typeof next.enabled !== 'boolean') {
    throw new Error('enabled must be a boolean');
  }
  if (!Number.isFinite(next.softCapTokens) || !Number.isFinite(next.hardCapTokens)) {
    throw new Error('Token caps must be finite numbers');
  }
  if (next.softCapTokens < 0 || next.hardCapTokens < 0) {
    throw new Error('Token caps must be non-negative');
  }
  if (!isBudgetPeriod(next.period)) {
    throw new Error(`Unknown budget period: ${String(next.period)}`);
  }
  if (next.softCapTokens > next.hardCapTokens) {
    throw new Error('softCapTokens cannot exceed hardCapTokens');
  }

  // Delete + insert must be atomic: if the insert failed on its own, the port
  // would be left without a row and readBudget would fall back to the
  // disabled default, silently lifting the caps.
  // NOTE(review): requires a Drizzle driver with transaction support.
  await db.transaction(async (tx) => {
    await tx
      .delete(systemSettings)
      .where(and(eq(systemSettings.key, KEY), eq(systemSettings.portId, portId)));
    await tx.insert(systemSettings).values({
      key: KEY,
      portId,
      value: next as unknown as Record<string, unknown>,
      updatedBy: userId,
    });
  });
  return next;
}

/** Returns the start-of-period UTC timestamp for the configured window. */
export function periodStart(period: BudgetPeriod, now: Date = new Date()): Date {
  const start = new Date(now);
  start.setUTCHours(0, 0, 0, 0);
  if (period === 'day') return start;
  if (period === 'week') {
    // Reset to Monday 00:00 UTC.
    const dow = (start.getUTCDay() + 6) % 7; // 0 = Monday
    start.setUTCDate(start.getUTCDate() - dow);
    return start;
  }
  // month
  start.setUTCDate(1);
  return start;
}

/**
 * Sum of ledger `totalTokens` for a port since `since`, optionally limited to
 * one feature. An empty `feature` string is treated as "all features".
 */
async function sumTokensSince(portId: string, since: Date, feature?: string): Promise<number> {
  const filters = [eq(aiUsageLedger.portId, portId), gte(aiUsageLedger.createdAt, since)];
  if (feature) filters.push(eq(aiUsageLedger.feature, feature));
  const [row] = await db
    .select({ total: sql`coalesce(sum(${aiUsageLedger.totalTokens}), 0)` })
    .from(aiUsageLedger)
    .where(and(...filters));
  // Number(): the driver may hand back bigint/numeric aggregates as strings.
  return Number(row?.total ?? 0);
}

/** Total tokens used in the current period, optionally filtered by feature. */
export async function currentPeriodTokens(portId: string, feature?: string): Promise<number> {
  const budget = await readBudget(portId);
  return sumTokensSince(portId, periodStart(budget.period), feature);
}

export type BudgetCheckResult =
  | { ok: true; remaining: number; usedTokens: number; softCap: boolean }
  | {
      ok: false;
      reason: 'hard-cap-exceeded' | 'budget-disabled-but-no-key' | 'estimated-exceeds-cap';
      usedTokens: number;
      capTokens: number;
    };

/**
 * Pre-flight gate: should we let this call proceed? Pass an `estimatedTokens`
 * value (e.g. max_tokens budget for the request) so we can refuse calls
 * that would *guarantee* hitting the cap, not just blow past it later.
 *
 * When the budget is disabled the call is always allowed with
 * `remaining: Infinity` and `usedTokens: 0` (no usage query is made).
 * Otherwise:
 *  - used >= hardCap             → `hard-cap-exceeded`
 *  - estimatedTokens > remaining → `estimated-exceeds-cap`
 *  - else ok, with `softCap: true` once used exceeds softCapTokens
 */
export async function checkBudget(args: {
  portId: string;
  estimatedTokens: number;
}): Promise<BudgetCheckResult> {
  const { portId, estimatedTokens } = args;
  const budget = await readBudget(portId);
  if (!budget.enabled) {
    // Budget is off — usage still gets logged, but no caps enforced.
    return { ok: true, remaining: Number.POSITIVE_INFINITY, usedTokens: 0, softCap: false };
  }

  // Reuse the budget loaded above instead of re-reading it via
  // currentPeriodTokens: one fewer query, and the caps and the window come
  // from the same snapshot.
  const used = await sumTokensSince(portId, periodStart(budget.period));
  const remaining = budget.hardCapTokens - used;

  if (remaining <= 0) {
    return {
      ok: false,
      reason: 'hard-cap-exceeded',
      usedTokens: used,
      capTokens: budget.hardCapTokens,
    };
  }
  if (estimatedTokens > remaining) {
    return {
      ok: false,
      reason: 'estimated-exceeds-cap',
      usedTokens: used,
      capTokens: budget.hardCapTokens,
    };
  }
  return {
    ok: true,
    remaining,
    usedTokens: used,
    softCap: used > budget.softCapTokens,
  };
}

interface RecordUsageInput {
  portId: string;
  userId?: string | null;
  feature: string;
  provider: string;
  model: string;
  inputTokens: number;
  outputTokens: number;
  requestId?: string | null;
}

/** Coerce a possibly-missing or non-finite token count to a finite number (0 otherwise). */
function toTokenCount(n: number): number {
  return Number.isFinite(n) ? n : 0;
}

/** Insert a ledger row. Never throws — logged failures degrade silently. */
export async function recordAiUsage(input: RecordUsageInput): Promise<void> {
  // Sanitize once so the per-column counts always agree with totalTokens
  // (previously NaN/undefined counts went to the columns raw).
  const inputTokens = toTokenCount(input.inputTokens);
  const outputTokens = toTokenCount(input.outputTokens);
  try {
    await db.insert(aiUsageLedger).values({
      portId: input.portId,
      userId: input.userId ?? null,
      feature: input.feature,
      provider: input.provider,
      model: input.model,
      inputTokens,
      outputTokens,
      totalTokens: inputTokens + outputTokens,
      requestId: input.requestId ?? null,
    });
  } catch (err) {
    // Don't fail the user-facing call because the ledger write hiccuped —
    // we'd rather silently lose a row than blow up an OCR scan.
    logger.error({ err, portId: input.portId, feature: input.feature }, 'recordAiUsage failed');
  }
}

/** Per-feature breakdown for the current period — feeds the admin dashboard. */
export async function periodBreakdown(
  portId: string,
): Promise<Array<{ feature: string; tokens: number; calls: number }>> {
  const budget = await readBudget(portId);
  const since = periodStart(budget.period);
  const rows = await db
    .select({
      feature: aiUsageLedger.feature,
      tokens: sql`coalesce(sum(${aiUsageLedger.totalTokens}), 0)`,
      calls: sql`count(*)::int`,
    })
    .from(aiUsageLedger)
    .where(and(eq(aiUsageLedger.portId, portId), gte(aiUsageLedger.createdAt, since)))
    .groupBy(aiUsageLedger.feature);

  // Number(): aggregates may arrive from the driver as strings.
  return rows.map((r) => ({
    feature: r.feature,
    tokens: Number(r.tokens),
    calls: Number(r.calls),
  }));
}