feat(saas): SquareMCP v2 — multi-tenant MCP platform complete
Steps 0–10 of the v2 plan, 194 tests passing. Core infrastructure - Shared Redis client (src/redis.ts); all four Redis consumers migrated - Vitest test harness with vitest.config.ts and npm test/test:watch scripts Billing & invoicing (Steps 1–2) - Monthly invoice generation with idempotency (MySQL uq_customer_period unique key) - Cron job with Redis distributed lock (Lua compare-delete, 1-hr TTL) - Invoice emailer via nodemailer (FETCHERPAY SMTP) - Billing middleware: checkLimit gate in handleToolCall; platform attribution fix Email multi-tenancy (Step 3) - EmailCtx = Account | EmailCredentials; imap.ts + smtp.ts accept both - resolveEmailCtx helper in tools.ts; all email tools use customer credentials Analytics + platform health (Steps 4–5) - Chart.js bar charts for platform breakdown and daily activity - Token expiry check in getCredential with dynamic import refresh - platform-health.ts: per-platform health probe with 10-min Redis cache - GET /api/health/platforms; "Token expired" amber badge in dashboard Tool schema filtering (Step 6) - stripAccountParam deep-clones tool schemas; multi-tenant sessions never see the internal account enum OAuth hardening (Step 7) - Atomic auth code consumption: UPDATE SET used=TRUE, check affectedRows - customer_id threaded through oauth_auth_codes → oauth_tokens - getTokenCustomer(); requireAuth resolves req.customer from Bearer token - Consent page requires authenticated session; redirect_uri validated against registered URIs; http://localhost:* loopback wildcard DCR browser flow (Step 8) - ensureOAuthAppRegistered() upserts pre-registered SquareMCP OAuth app on startup with redirect URIs for mcp-callback, localhost:*, claude-desktop, opencode - GET /oauth/connect-mcp → server-side redirect (client_id off frontend) - GET /oauth/mcp-callback → exchanges code, renders config snippet page with copy buttons for Claude Desktop and Codex CLI Webhooks (Step 9) - webhook_url + webhook_secret columns on customers - deliverWebhook(): HMAC-SHA256 signing, 3× exponential retry (1s/4s/16s), Redis DLQ with 7-day TTL on total failure - isValidWebhookUrl(): SSRF protection (blocks RFC-1918, localhost, .local) - POST /api/webhooks/config (secret returned once), GET, DELETE - GET /api/admin/webhooks/dlq/:customerId - WhatsApp POST route uses express.raw() for raw body preservation - Dashboard Webhooks tab with secret-once display and copy button Developer docs (Step 10) - docs/ static HTML site (GitHub Pages, no build pipeline) - index.html: landing page with client + platform overview - getting-started.html: tabbed MCP config for Claude Desktop, Codex CLI, opencode - platforms.html: LinkedIn, TikTok, WhatsApp, Instagram, Twitter, Telegram guides - agent-tutorial.html: complete Node.js agent (Anthropic SDK + MCP SDK), LinkedIn posting loop, extensions for multi-platform + inbound webhook reaction Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
96
src/billing/cron.test.ts
Normal file
96
src/billing/cron.test.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
const { mockRedisSet, mockRedisEval } = vi.hoisted(() => ({
|
||||
mockRedisSet: vi.fn(),
|
||||
mockRedisEval: vi.fn(),
|
||||
}));
|
||||
|
||||
const { mockQuery } = vi.hoisted(() => ({ mockQuery: vi.fn() }));
|
||||
|
||||
const { mockGenerateMonthlyInvoice } = vi.hoisted(() => ({
|
||||
mockGenerateMonthlyInvoice: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../redis.js', () => ({
|
||||
default: { set: mockRedisSet, eval: mockRedisEval },
|
||||
}));
|
||||
|
||||
vi.mock('../db.js', () => ({
|
||||
getPool: () => ({ query: mockQuery }),
|
||||
}));
|
||||
|
||||
vi.mock('./invoices.js', () => ({
|
||||
generateMonthlyInvoice: (...args: any[]) => mockGenerateMonthlyInvoice(...args),
|
||||
}));
|
||||
|
||||
import { runInvoiceCron } from './cron.js';
|
||||
|
||||
const TWO_CUSTOMERS = [[{ id: 'cust-a' }, { id: 'cust-b' }]];
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockRedisEval.mockResolvedValue(1);
|
||||
mockGenerateMonthlyInvoice.mockResolvedValue(null);
|
||||
});
|
||||
|
||||
describe('runInvoiceCron — Redis lock', () => {
|
||||
it('runs invoice loop when lock is acquired', async () => {
|
||||
mockRedisSet.mockResolvedValue('OK');
|
||||
mockQuery.mockResolvedValue(TWO_CUSTOMERS);
|
||||
|
||||
await runInvoiceCron();
|
||||
|
||||
expect(mockRedisSet).toHaveBeenCalledWith(
|
||||
'invoice:cron:lock',
|
||||
expect.any(String),
|
||||
{ NX: true, EX: 3600 }
|
||||
);
|
||||
expect(mockQuery).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('skips invoice loop when lock is already held', async () => {
|
||||
mockRedisSet.mockResolvedValue(null);
|
||||
|
||||
await runInvoiceCron();
|
||||
|
||||
expect(mockQuery).not.toHaveBeenCalled();
|
||||
expect(mockGenerateMonthlyInvoice).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('releases lock via compare-delete Lua script after success', async () => {
|
||||
mockRedisSet.mockResolvedValue('OK');
|
||||
mockQuery.mockResolvedValue([[{ id: 'cust-1' }]]);
|
||||
|
||||
await runInvoiceCron();
|
||||
|
||||
expect(mockRedisEval).toHaveBeenCalledWith(
|
||||
expect.stringContaining('redis.call("get"'),
|
||||
expect.objectContaining({ keys: ['invoice:cron:lock'], arguments: [expect.any(String)] })
|
||||
);
|
||||
});
|
||||
|
||||
it('does NOT release lock when the loop throws', async () => {
|
||||
mockRedisSet.mockResolvedValue('OK');
|
||||
mockQuery.mockRejectedValue(new Error('DB down'));
|
||||
|
||||
await expect(runInvoiceCron()).rejects.toThrow('DB down');
|
||||
expect(mockRedisEval).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('runInvoiceCron — per-customer error isolation', () => {
|
||||
it('continues processing remaining customers after one fails', async () => {
|
||||
mockRedisSet.mockResolvedValue('OK');
|
||||
mockQuery.mockResolvedValue([[{ id: 'cust-a' }, { id: 'cust-b' }, { id: 'cust-c' }]]);
|
||||
|
||||
mockGenerateMonthlyInvoice
|
||||
.mockResolvedValueOnce({ invoice_number: 'SMCP-1', customer_id: 'cust-a' })
|
||||
.mockRejectedValueOnce(new Error('Stripe error'))
|
||||
.mockResolvedValueOnce({ invoice_number: 'SMCP-3', customer_id: 'cust-c' });
|
||||
|
||||
await runInvoiceCron();
|
||||
|
||||
expect(mockGenerateMonthlyInvoice).toHaveBeenCalledTimes(3);
|
||||
expect(mockRedisEval).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
51
src/billing/cron.ts
Normal file
51
src/billing/cron.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { randomUUID } from 'crypto';
|
||||
import redis from '../redis.js';
|
||||
import { getPool } from '../db.js';
|
||||
import { generateMonthlyInvoice } from './invoices.js';
|
||||
|
||||
const LOCK_KEY = 'invoice:cron:lock';
|
||||
const LOCK_TTL_SECONDS = 3600;
|
||||
|
||||
// Only releases the lock if the value matches — prevents one replica from
|
||||
// deleting another's lock after a TTL expiry race.
|
||||
const COMPARE_DELETE_LUA = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`;
|
||||
|
||||
export async function runInvoiceCron(): Promise<void> {
|
||||
const lockValue = randomUUID();
|
||||
|
||||
const acquired = await redis.set(LOCK_KEY, lockValue, { NX: true, EX: LOCK_TTL_SECONDS });
|
||||
if (!acquired) {
|
||||
console.log('[cron] Another process holds the invoice lock — skipping');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const [customers] = await getPool().query<any[]>(
|
||||
"SELECT id FROM customers WHERE active = TRUE AND plan != 'free'"
|
||||
);
|
||||
|
||||
for (const customer of customers) {
|
||||
try {
|
||||
const invoice = await generateMonthlyInvoice(customer.id);
|
||||
if (invoice) {
|
||||
console.log(`[cron] Generated invoice ${invoice.invoice_number} for customer ${customer.id}`);
|
||||
}
|
||||
} catch (err) {
|
||||
// One customer failing must not abort the rest
|
||||
console.error(`[cron] Failed for customer ${customer.id}:`, err);
|
||||
}
|
||||
}
|
||||
|
||||
await redis.eval(COMPARE_DELETE_LUA, { keys: [LOCK_KEY], arguments: [lockValue] });
|
||||
} catch (err) {
|
||||
// Leave the lock in place on unexpected failure — TTL releases it after 1 hour
|
||||
console.error('[cron] Invoice cron fatal error:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
151
src/billing/invoices.test.ts
Normal file
151
src/billing/invoices.test.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import nodemailer from 'nodemailer';
|
||||
|
||||
const mockQuery = vi.fn();
|
||||
vi.mock('../db.js', () => ({
|
||||
getPool: () => ({ query: mockQuery }),
|
||||
}));
|
||||
vi.mock('nodemailer', () => ({
|
||||
default: { createTransport: vi.fn() },
|
||||
}));
|
||||
|
||||
import {
|
||||
createInvoice,
|
||||
generateMonthlyInvoice,
|
||||
getInvoiceByNumber,
|
||||
emailInvoice,
|
||||
type Invoice,
|
||||
} from './invoices.js';
|
||||
|
||||
const baseInvoice: Invoice = {
|
||||
id: 1,
|
||||
customer_id: 'cust-1',
|
||||
invoice_number: 'SMCP-20260501-aabb1122',
|
||||
amount: 25.0,
|
||||
currency: 'USD',
|
||||
status: 'draft',
|
||||
period_start: '2026-04-01',
|
||||
period_end: '2026-04-30',
|
||||
line_items: [{ description: 'email actions', quantity: 500, unit_price: 0.05, amount: 25.0 }],
|
||||
sent_at: null,
|
||||
paid_at: null,
|
||||
created_at: '2026-05-01T00:00:00Z',
|
||||
constructor: { name: 'RowDataPacket' } as any,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
// ── generateInvoiceNumber format ─────────────────────────────────────────────
|
||||
|
||||
describe('generateMonthlyInvoice — billing period', () => {
|
||||
it('queries previous month, not current', async () => {
|
||||
mockQuery.mockResolvedValueOnce([[]]); // usage query returns empty
|
||||
await generateMonthlyInvoice('cust-1');
|
||||
const sql: string = mockQuery.mock.calls[0][0];
|
||||
// Must reference DATE_SUB for the start of the previous month
|
||||
expect(sql).toContain('DATE_SUB');
|
||||
// Must NOT use a plain NOW() as the lower bound (that would be current month)
|
||||
expect(sql).not.toMatch(/>=\s*DATE_FORMAT\(NOW\(\)/);
|
||||
});
|
||||
|
||||
it('returns null when customer has zero usage', async () => {
|
||||
mockQuery.mockResolvedValueOnce([[]]); // no usage rows
|
||||
const result = await generateMonthlyInvoice('cust-1');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('creates invoice for previous month period', async () => {
|
||||
const now = new Date();
|
||||
const expectedStart = new Date(now.getFullYear(), now.getMonth() - 1, 1)
|
||||
.toISOString().slice(0, 10);
|
||||
|
||||
mockQuery
|
||||
.mockResolvedValueOnce([[{ platform: 'email', count: 100 }]]) // usage
|
||||
.mockResolvedValueOnce([[]]) // insert
|
||||
.mockResolvedValueOnce([[{ ...baseInvoice, period_start: expectedStart }]]); // select after insert
|
||||
|
||||
const invoice = await generateMonthlyInvoice('cust-1');
|
||||
expect(invoice?.period_start).toBe(expectedStart);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Invoice number format ─────────────────────────────────────────────────────
|
||||
|
||||
describe('createInvoice — invoice number', () => {
|
||||
it('generates hex suffix, not decimal padded', async () => {
|
||||
mockQuery
|
||||
.mockResolvedValueOnce([[]]) // insert
|
||||
.mockResolvedValueOnce([[baseInvoice]]); // select
|
||||
|
||||
await createInvoice('cust-1', 25, baseInvoice.line_items, '2026-04-01', '2026-04-30');
|
||||
|
||||
const insertSql: string = mockQuery.mock.calls[0][0];
|
||||
const invoiceNumber: string = mockQuery.mock.calls[0][1][1];
|
||||
// Should be SMCP-YYYYMMDD-<8 hex chars>
|
||||
expect(invoiceNumber).toMatch(/^SMCP-\d{8}-[0-9a-f]{8}$/);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Idempotency ───────────────────────────────────────────────────────────────
|
||||
|
||||
describe('createInvoice — idempotency', () => {
|
||||
it('returns existing invoice on uq_customer_period duplicate', async () => {
|
||||
const dupError = Object.assign(new Error("Duplicate entry 'cust-1-2026-04-01' for key 'invoices.uq_customer_period'"), { errno: 1062 });
|
||||
mockQuery
|
||||
.mockRejectedValueOnce(dupError) // INSERT throws 1062
|
||||
.mockResolvedValueOnce([[baseInvoice]]); // SELECT existing
|
||||
|
||||
const invoice = await createInvoice('cust-1', 25, baseInvoice.line_items, '2026-04-01', '2026-04-30');
|
||||
expect(invoice.invoice_number).toBe(baseInvoice.invoice_number);
|
||||
expect(invoice.customer_id).toBe('cust-1');
|
||||
});
|
||||
|
||||
it('re-throws 1062 for other constraints (e.g. duplicate invoice_number)', async () => {
|
||||
const dupError = Object.assign(new Error("Duplicate entry 'SMCP-...' for key 'invoices.invoice_number'"), { errno: 1062 });
|
||||
mockQuery.mockRejectedValueOnce(dupError);
|
||||
|
||||
await expect(createInvoice('cust-1', 25, baseInvoice.line_items, '2026-04-01', '2026-04-30'))
|
||||
.rejects.toThrow('invoice_number');
|
||||
});
|
||||
});
|
||||
|
||||
// ── line_items JSON normalization ─────────────────────────────────────────────
|
||||
|
||||
describe('getInvoiceByNumber — line_items normalization', () => {
|
||||
it('parses line_items when returned as JSON string', async () => {
|
||||
const rawInvoice = {
|
||||
...baseInvoice,
|
||||
line_items: JSON.stringify(baseInvoice.line_items) as any,
|
||||
};
|
||||
mockQuery.mockResolvedValueOnce([[rawInvoice]]);
|
||||
|
||||
const invoice = await getInvoiceByNumber('SMCP-20260501-aabb1122');
|
||||
expect(Array.isArray(invoice?.line_items)).toBe(true);
|
||||
expect(invoice?.line_items[0].description).toBe('email actions');
|
||||
});
|
||||
|
||||
it('leaves line_items unchanged when already an array', async () => {
|
||||
mockQuery.mockResolvedValueOnce([[baseInvoice]]);
|
||||
const invoice = await getInvoiceByNumber('SMCP-20260501-aabb1122');
|
||||
expect(Array.isArray(invoice?.line_items)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
// ── emailInvoice ─────────────────────────────────────────────────────────────
|
||||
|
||||
describe('emailInvoice', () => {
|
||||
it('sends mail with invoice_number in subject and HTML body', async () => {
|
||||
const sendMail = vi.fn().mockResolvedValue({ messageId: '<abc@mail>' });
|
||||
vi.mocked(nodemailer.createTransport).mockReturnValue({ sendMail } as any);
|
||||
|
||||
await emailInvoice(baseInvoice, 'customer@example.com');
|
||||
|
||||
expect(sendMail).toHaveBeenCalledOnce();
|
||||
const mailArgs = sendMail.mock.calls[0][0];
|
||||
expect(mailArgs.to).toBe('customer@example.com');
|
||||
expect(mailArgs.subject).toContain(baseInvoice.invoice_number);
|
||||
expect(mailArgs.html).toContain(baseInvoice.invoice_number);
|
||||
});
|
||||
});
|
||||
@@ -1,3 +1,5 @@
|
||||
import { randomBytes } from 'crypto';
|
||||
import nodemailer from 'nodemailer';
|
||||
import { getPool } from '../db.js';
|
||||
import type { RowDataPacket } from 'mysql2';
|
||||
|
||||
@@ -24,10 +26,16 @@ export interface Invoice extends RowDataPacket {
|
||||
}
|
||||
|
||||
function generateInvoiceNumber(): string {
|
||||
const prefix = 'SMCP';
|
||||
const date = new Date().toISOString().slice(0, 10).replace(/-/g, '');
|
||||
const random = Math.floor(Math.random() * 10000).toString().padStart(4, '0');
|
||||
return `${prefix}-${date}-${random}`;
|
||||
const suffix = randomBytes(4).toString('hex');
|
||||
return `SMCP-${date}-${suffix}`;
|
||||
}
|
||||
|
||||
function normalizeInvoice(invoice: Invoice): Invoice {
|
||||
if (typeof invoice.line_items === 'string') {
|
||||
invoice.line_items = JSON.parse(invoice.line_items);
|
||||
}
|
||||
return invoice;
|
||||
}
|
||||
|
||||
export async function createInvoice(
|
||||
@@ -38,16 +46,27 @@ export async function createInvoice(
|
||||
periodEnd: string
|
||||
): Promise<Invoice> {
|
||||
const invoiceNumber = generateInvoiceNumber();
|
||||
await getPool().query(
|
||||
`INSERT INTO invoices (customer_id, invoice_number, amount, line_items, period_start, period_end)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
[customerId, invoiceNumber, amount, JSON.stringify(lineItems), periodStart, periodEnd]
|
||||
);
|
||||
try {
|
||||
await getPool().query(
|
||||
`INSERT INTO invoices (customer_id, invoice_number, amount, line_items, period_start, period_end)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
[customerId, invoiceNumber, amount, JSON.stringify(lineItems), periodStart, periodEnd]
|
||||
);
|
||||
} catch (err: any) {
|
||||
if (err.errno === 1062 && err.message.includes('uq_customer_period')) {
|
||||
const [rows] = await getPool().query<Invoice[]>(
|
||||
'SELECT * FROM invoices WHERE customer_id = ? AND period_start = ?',
|
||||
[customerId, periodStart]
|
||||
);
|
||||
return normalizeInvoice(rows[0]);
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
const [rows] = await getPool().query<Invoice[]>(
|
||||
'SELECT * FROM invoices WHERE invoice_number = ?',
|
||||
[invoiceNumber]
|
||||
);
|
||||
return rows[0];
|
||||
return normalizeInvoice(rows[0]);
|
||||
}
|
||||
|
||||
export async function getCustomerInvoices(customerId: string): Promise<Invoice[]> {
|
||||
@@ -55,7 +74,7 @@ export async function getCustomerInvoices(customerId: string): Promise<Invoice[]
|
||||
'SELECT * FROM invoices WHERE customer_id = ? ORDER BY created_at DESC',
|
||||
[customerId]
|
||||
);
|
||||
return rows;
|
||||
return rows.map(normalizeInvoice);
|
||||
}
|
||||
|
||||
export async function getInvoiceByNumber(invoiceNumber: string): Promise<Invoice | null> {
|
||||
@@ -63,7 +82,7 @@ export async function getInvoiceByNumber(invoiceNumber: string): Promise<Invoice
|
||||
'SELECT * FROM invoices WHERE invoice_number = ?',
|
||||
[invoiceNumber]
|
||||
);
|
||||
return rows[0] ?? null;
|
||||
return rows[0] ? normalizeInvoice(rows[0]) : null;
|
||||
}
|
||||
|
||||
export async function markInvoiceSent(invoiceNumber: string): Promise<void> {
|
||||
@@ -80,12 +99,45 @@ export async function markInvoicePaid(invoiceNumber: string): Promise<void> {
|
||||
);
|
||||
}
|
||||
|
||||
export async function emailInvoice(invoice: Invoice, toEmail: string): Promise<void> {
|
||||
const transport = nodemailer.createTransport({
|
||||
host: process.env.FETCHERPAY_SMTP_HOST ?? 'mail.fetcherpay.com',
|
||||
port: parseInt(process.env.FETCHERPAY_SMTP_PORT ?? '30587', 10),
|
||||
secure: false,
|
||||
auth: {
|
||||
user: process.env.BILLING_EMAIL ?? process.env.FETCHERPAY_EMAIL,
|
||||
pass: process.env.BILLING_PASSWORD ?? process.env.FETCHERPAY_PASSWORD,
|
||||
},
|
||||
tls: { rejectUnauthorized: false },
|
||||
});
|
||||
await transport.sendMail({
|
||||
from: process.env.BILLING_FROM ?? 'billing@squaremcp.com',
|
||||
to: toEmail,
|
||||
subject: `Invoice ${invoice.invoice_number} from SquareMCP`,
|
||||
html: `
|
||||
<h1>Invoice ${invoice.invoice_number}</h1>
|
||||
<p>Billing period: ${invoice.period_start} – ${invoice.period_end}</p>
|
||||
<p>Amount due: $${invoice.amount} ${invoice.currency}</p>
|
||||
<table>
|
||||
<tr><th>Description</th><th>Qty</th><th>Unit price</th><th>Amount</th></tr>
|
||||
${invoice.line_items.map((li) =>
|
||||
`<tr><td>${li.description}</td><td>${li.quantity}</td><td>$${li.unit_price}</td><td>$${li.amount}</td></tr>`
|
||||
).join('')}
|
||||
</table>
|
||||
`,
|
||||
});
|
||||
}
|
||||
|
||||
export async function generateMonthlyInvoice(customerId: string): Promise<Invoice | null> {
|
||||
const now = new Date();
|
||||
const prevMonthStart = new Date(now.getFullYear(), now.getMonth() - 1, 1);
|
||||
const prevMonthEnd = new Date(now.getFullYear(), now.getMonth(), 0);
|
||||
|
||||
const [usageRows] = await getPool().query<any[]>(
|
||||
`SELECT platform, COUNT(*) as count FROM usage_logs
|
||||
WHERE customer_id = ?
|
||||
AND created_at >= DATE_FORMAT(NOW(), '%Y-%m-01')
|
||||
AND created_at < DATE_FORMAT(DATE_ADD(NOW(), INTERVAL 1 MONTH), '%Y-%m-01')
|
||||
AND created_at >= DATE_FORMAT(DATE_SUB(NOW(), INTERVAL 1 MONTH), '%Y-%m-01')
|
||||
AND created_at < DATE_FORMAT(NOW(), '%Y-%m-01')
|
||||
GROUP BY platform`,
|
||||
[customerId]
|
||||
);
|
||||
@@ -97,7 +149,7 @@ export async function generateMonthlyInvoice(customerId: string): Promise<Invoic
|
||||
|
||||
for (const row of usageRows) {
|
||||
const qty = row.count;
|
||||
const unitPrice = 0.05; // $0.05 per action
|
||||
const unitPrice = 0.05;
|
||||
const amount = qty * unitPrice;
|
||||
total += amount;
|
||||
lineItems.push({
|
||||
@@ -108,15 +160,11 @@ export async function generateMonthlyInvoice(customerId: string): Promise<Invoic
|
||||
});
|
||||
}
|
||||
|
||||
const now = new Date();
|
||||
const start = new Date(now.getFullYear(), now.getMonth(), 1);
|
||||
const end = new Date(now.getFullYear(), now.getMonth() + 1, 0);
|
||||
|
||||
return createInvoice(
|
||||
customerId,
|
||||
parseFloat(total.toFixed(2)),
|
||||
lineItems,
|
||||
start.toISOString().slice(0, 10),
|
||||
end.toISOString().slice(0, 10)
|
||||
prevMonthStart.toISOString().slice(0, 10),
|
||||
prevMonthEnd.toISOString().slice(0, 10)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import { createClient } from 'redis';
|
||||
import { RowDataPacket } from 'mysql2';
|
||||
import { getPool } from '../db.js';
|
||||
import { getCredential, Platform, PlatformCredentials } from '../multitenancy/credential-store.js';
|
||||
import type { PlanKey } from './plans.js';
|
||||
import type { Request, Response, NextFunction } from 'express';
|
||||
import { verifyJWT } from '../auth.js';
|
||||
|
||||
const redis = createClient({ url: process.env.REDIS_URL });
|
||||
redis.connect().catch((err) => console.error('[billing] Redis connect error:', err));
|
||||
import redis from '../redis.js';
|
||||
|
||||
export interface Customer {
|
||||
id: string;
|
||||
|
||||
@@ -20,11 +20,11 @@ export const PLANS: Record<PlanKey, Plan> = {
|
||||
growth: {
|
||||
name: 'Growth',
|
||||
monthlyCallLimit: 10000,
|
||||
platforms: ['email', 'obsidian', 'whatsapp', 'telegram', 'discord', 'instagram', 'linkedin', 'twitter'],
|
||||
platforms: ['email', 'obsidian', 'whatsapp', 'telegram', 'discord', 'instagram', 'linkedin', 'twitter', 'tiktok', 'facebook', 'snapchat'],
|
||||
},
|
||||
enterprise: {
|
||||
name: 'Enterprise',
|
||||
monthlyCallLimit: -1,
|
||||
platforms: ['email', 'obsidian', 'whatsapp', 'telegram', 'discord', 'instagram', 'linkedin', 'twitter'],
|
||||
platforms: ['email', 'obsidian', 'whatsapp', 'telegram', 'discord', 'instagram', 'linkedin', 'twitter', 'tiktok', 'facebook', 'snapchat'],
|
||||
},
|
||||
};
|
||||
|
||||
47
src/billing/usage.test.ts
Normal file
47
src/billing/usage.test.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { checkLimit } from './usage.js';
|
||||
|
||||
vi.mock('../db.js', () => ({
|
||||
getPool: vi.fn(() => ({
|
||||
query: vi.fn().mockResolvedValue([[{ count: 50 }]]),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('checkLimit', () => {
|
||||
it('always allows enterprise plan regardless of usage', async () => {
|
||||
const result = await checkLimit('cust-1', 'enterprise');
|
||||
expect(result).toEqual({ allowed: true, limit: -1, used: 0 });
|
||||
});
|
||||
|
||||
it('allows when usage is under limit', async () => {
|
||||
const { getPool } = await import('../db.js');
|
||||
vi.mocked(getPool).mockReturnValue({ query: vi.fn().mockResolvedValue([[{ count: 50 }]]) } as any);
|
||||
const result = await checkLimit('cust-1', 'growth');
|
||||
expect(result.allowed).toBe(true);
|
||||
expect(result.used).toBe(50);
|
||||
expect(result.limit).toBe(10000);
|
||||
});
|
||||
|
||||
it('blocks when usage equals limit', async () => {
|
||||
const { getPool } = await import('../db.js');
|
||||
vi.mocked(getPool).mockReturnValue({ query: vi.fn().mockResolvedValue([[{ count: 1000 }]]) } as any);
|
||||
const result = await checkLimit('cust-1', 'starter');
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.used).toBe(1000);
|
||||
});
|
||||
|
||||
it('blocks when usage exceeds limit', async () => {
|
||||
const { getPool } = await import('../db.js');
|
||||
vi.mocked(getPool).mockReturnValue({ query: vi.fn().mockResolvedValue([[{ count: 1001 }]]) } as any);
|
||||
const result = await checkLimit('cust-1', 'free');
|
||||
expect(result.allowed).toBe(false);
|
||||
});
|
||||
|
||||
it('does not query DB for enterprise plan', async () => {
|
||||
const { getPool } = await import('../db.js');
|
||||
const querySpy = vi.fn();
|
||||
vi.mocked(getPool).mockReturnValue({ query: querySpy } as any);
|
||||
await checkLimit('cust-1', 'enterprise');
|
||||
expect(querySpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
46
src/db.ts
46
src/db.ts
@@ -14,23 +14,29 @@ async function ensureColumn(
|
||||
definition: string
|
||||
): Promise<void> {
|
||||
const [rows] = await db.execute<mysql.RowDataPacket[]>(
|
||||
`
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = ?
|
||||
AND COLUMN_NAME = ?
|
||||
`,
|
||||
`SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND COLUMN_NAME = ?`,
|
||||
[tableName, columnName]
|
||||
);
|
||||
|
||||
if (Array.isArray(rows) && rows.length > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (Array.isArray(rows) && rows.length > 0) return;
|
||||
await db.execute(`ALTER TABLE \`${tableName}\` ADD COLUMN \`${columnName}\` ${definition}`);
|
||||
}
|
||||
|
||||
async function ensureIndex(
|
||||
db: mysql.PoolConnection,
|
||||
tableName: string,
|
||||
indexName: string,
|
||||
definition: string
|
||||
): Promise<void> {
|
||||
const [rows] = await db.execute<mysql.RowDataPacket[]>(
|
||||
`SELECT INDEX_NAME FROM INFORMATION_SCHEMA.STATISTICS
|
||||
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ? AND INDEX_NAME = ?`,
|
||||
[tableName, indexName]
|
||||
);
|
||||
if (Array.isArray(rows) && rows.length > 0) return;
|
||||
await db.execute(`ALTER TABLE \`${tableName}\` ADD ${definition}`);
|
||||
}
|
||||
|
||||
export function getPool(): mysql.Pool {
|
||||
if (!pool) {
|
||||
throw new Error('Database pool not initialized. Call initDatabase() first.');
|
||||
@@ -93,7 +99,11 @@ export async function initDatabase(): Promise<void> {
|
||||
await ensureColumn(db, 'oauth_auth_codes', 'scope', 'TEXT NULL');
|
||||
await ensureColumn(db, 'oauth_auth_codes', 'code_challenge', 'TEXT NULL');
|
||||
await ensureColumn(db, 'oauth_auth_codes', 'code_challenge_method', 'VARCHAR(20) NULL');
|
||||
await ensureColumn(db, 'oauth_auth_codes', 'customer_id', 'VARCHAR(255) NULL');
|
||||
await ensureColumn(db, 'oauth_tokens', 'customer_id', 'VARCHAR(255) NULL');
|
||||
await ensureColumn(db, 'customers', 'password_hash', 'VARCHAR(255) NULL');
|
||||
await ensureColumn(db, 'customers', 'webhook_url', 'VARCHAR(512) NULL');
|
||||
await ensureColumn(db, 'customers', 'webhook_secret', 'VARCHAR(64) NULL');
|
||||
|
||||
await db.execute(`
|
||||
CREATE TABLE IF NOT EXISTS oauth_tokens (
|
||||
@@ -158,6 +168,18 @@ export async function initDatabase(): Promise<void> {
|
||||
await ensureColumn(db, 'customers', 'role', "ENUM('user','admin') DEFAULT 'user'");
|
||||
await ensureColumn(db, 'customers', 'reset_token', 'VARCHAR(255) NULL');
|
||||
await ensureColumn(db, 'customers', 'reset_expires_at', 'TIMESTAMP NULL');
|
||||
|
||||
// Remove duplicate (customer_id, period_start) rows before adding unique constraint
|
||||
// Keeps the earliest invoice (lowest id) for each customer+period pair
|
||||
await db.execute(`
|
||||
DELETE i1 FROM invoices i1
|
||||
INNER JOIN invoices i2
|
||||
ON i1.customer_id = i2.customer_id
|
||||
AND i1.period_start = i2.period_start
|
||||
AND i1.id > i2.id
|
||||
`);
|
||||
await ensureIndex(db, 'invoices', 'uq_customer_period',
|
||||
'UNIQUE KEY uq_customer_period (customer_id, period_start)');
|
||||
} finally {
|
||||
db.release();
|
||||
}
|
||||
|
||||
117
src/imap.ts
117
src/imap.ts
@@ -1,6 +1,8 @@
|
||||
import { ImapFlow } from 'imapflow';
|
||||
import type { EmailCredentials } from './multitenancy/credential-store.js';
|
||||
|
||||
export type Account = 'yahoo' | 'fetcherpay' | 'garfield' | 'sales' | 'leads' | 'founder' | 'gmail';
|
||||
export type EmailCtx = Account | EmailCredentials;
|
||||
|
||||
const FETCHERPAY_IMAP_HOST = process.env['FETCHERPAY_IMAP_HOST'] ?? 'mail.fetcherpay.com';
|
||||
const FETCHERPAY_IMAP_PORT = parseInt(process.env['FETCHERPAY_IMAP_PORT'] ?? '30993');
|
||||
@@ -15,41 +17,26 @@ function fetcherpayImapConfig(user: string, pass: string) {
|
||||
};
|
||||
}
|
||||
|
||||
function getConfig(account: Account = 'yahoo') {
|
||||
function getEnvConfig(account: Account = 'yahoo') {
|
||||
switch (account) {
|
||||
case 'fetcherpay':
|
||||
return fetcherpayImapConfig(
|
||||
process.env['FETCHERPAY_EMAIL'] as string,
|
||||
process.env['FETCHERPAY_PASSWORD'] as string,
|
||||
);
|
||||
return fetcherpayImapConfig(process.env['FETCHERPAY_EMAIL']!, process.env['FETCHERPAY_PASSWORD']!);
|
||||
case 'garfield':
|
||||
return fetcherpayImapConfig(
|
||||
process.env['GARFIELD_EMAIL'] as string,
|
||||
process.env['GARFIELD_PASSWORD'] as string,
|
||||
);
|
||||
return fetcherpayImapConfig(process.env['GARFIELD_EMAIL']!, process.env['GARFIELD_PASSWORD']!);
|
||||
case 'sales':
|
||||
return fetcherpayImapConfig(
|
||||
process.env['SALES_EMAIL'] as string,
|
||||
process.env['SALES_PASSWORD'] as string,
|
||||
);
|
||||
return fetcherpayImapConfig(process.env['SALES_EMAIL']!, process.env['SALES_PASSWORD']!);
|
||||
case 'leads':
|
||||
return fetcherpayImapConfig(
|
||||
process.env['LEADS_EMAIL'] as string,
|
||||
process.env['LEADS_PASSWORD'] as string,
|
||||
);
|
||||
return fetcherpayImapConfig(process.env['LEADS_EMAIL']!, process.env['LEADS_PASSWORD']!);
|
||||
case 'founder':
|
||||
return fetcherpayImapConfig(
|
||||
process.env['FOUNDER_EMAIL'] as string,
|
||||
process.env['FOUNDER_PASSWORD'] as string,
|
||||
);
|
||||
return fetcherpayImapConfig(process.env['FOUNDER_EMAIL']!, process.env['FOUNDER_PASSWORD']!);
|
||||
case 'gmail':
|
||||
return {
|
||||
host: 'imap.gmail.com',
|
||||
port: 993,
|
||||
secure: true,
|
||||
auth: {
|
||||
user: process.env['GMAIL_EMAIL'] as string,
|
||||
pass: process.env['GMAIL_APP_PASSWORD'] as string,
|
||||
user: process.env['GMAIL_EMAIL']!,
|
||||
pass: process.env['GMAIL_APP_PASSWORD']!,
|
||||
},
|
||||
};
|
||||
default:
|
||||
@@ -58,15 +45,33 @@ function getConfig(account: Account = 'yahoo') {
|
||||
port: 993,
|
||||
secure: true,
|
||||
auth: {
|
||||
user: process.env['YAHOO_EMAIL'] as string,
|
||||
pass: process.env['YAHOO_APP_PASSWORD'] as string,
|
||||
user: process.env['YAHOO_EMAIL']!,
|
||||
pass: process.env['YAHOO_APP_PASSWORD']!,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function withClient<T>(account: Account, fn: (client: ImapFlow) => Promise<T>): Promise<T> {
|
||||
const client = new ImapFlow(getConfig(account));
|
||||
function resolveImapConfig(ctx: EmailCtx) {
|
||||
if (typeof ctx === 'object') {
|
||||
return {
|
||||
host: ctx.host,
|
||||
port: ctx.port,
|
||||
secure: ctx.port === 993,
|
||||
auth: { user: ctx.user, pass: ctx.password },
|
||||
tls: { rejectUnauthorized: false },
|
||||
};
|
||||
}
|
||||
return getEnvConfig(ctx);
|
||||
}
|
||||
|
||||
function isGmail(ctx: EmailCtx): boolean {
|
||||
if (typeof ctx === 'string') return ctx === 'gmail';
|
||||
return ctx.host === 'imap.gmail.com';
|
||||
}
|
||||
|
||||
async function withClient<T>(ctx: EmailCtx, fn: (client: ImapFlow) => Promise<T>): Promise<T> {
|
||||
const client = new ImapFlow(resolveImapConfig(ctx));
|
||||
await client.connect();
|
||||
try {
|
||||
return await fn(client);
|
||||
@@ -101,7 +106,6 @@ function parseSearchCriteria(query: string): object {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) return { all: true };
|
||||
|
||||
// Parse quoted and unquoted tokens like from:x, subject:y, to:z
|
||||
const tokens: { key: string; value: string }[] = [];
|
||||
const regex = /(\w+):("([^"]*)"|([^\s]+))/g;
|
||||
let m: RegExpExecArray | null;
|
||||
@@ -123,15 +127,9 @@ function parseSearchCriteria(query: string): object {
|
||||
const parts: object[] = [];
|
||||
for (const t of tokens) {
|
||||
switch (t.key) {
|
||||
case 'from':
|
||||
parts.push({ from: t.value });
|
||||
break;
|
||||
case 'subject':
|
||||
parts.push({ subject: t.value });
|
||||
break;
|
||||
case 'to':
|
||||
parts.push({ to: t.value });
|
||||
break;
|
||||
case 'from': parts.push({ from: t.value }); break;
|
||||
case 'subject': parts.push({ subject: t.value }); break;
|
||||
case 'to': parts.push({ to: t.value }); break;
|
||||
case 'after':
|
||||
case 'since': {
|
||||
const d = new Date(t.value);
|
||||
@@ -159,11 +157,11 @@ async function searchInFolder(
|
||||
folder: string,
|
||||
query: string,
|
||||
maxResults: number,
|
||||
account: Account
|
||||
ctx: EmailCtx
|
||||
): Promise<MessageSummary[]> {
|
||||
await client.mailboxOpen(folder);
|
||||
|
||||
const criteria = account === 'gmail'
|
||||
const criteria = isGmail(ctx)
|
||||
? { gmailraw: query }
|
||||
: parseSearchCriteria(query);
|
||||
|
||||
@@ -200,28 +198,25 @@ async function searchInFolder(
|
||||
export async function searchMessages(
|
||||
query: string,
|
||||
maxResults = 20,
|
||||
account: Account = 'yahoo',
|
||||
ctx: EmailCtx = 'yahoo',
|
||||
folder?: string
|
||||
): Promise<MessageSummary[]> {
|
||||
return withClient(account, async (client) => {
|
||||
return withClient(ctx, async (client) => {
|
||||
const foldersToSearch: string[] = [];
|
||||
|
||||
if (folder) {
|
||||
foldersToSearch.push(folder);
|
||||
} else if (account === 'gmail') {
|
||||
foldersToSearch.push('INBOX');
|
||||
} else {
|
||||
foldersToSearch.push('INBOX');
|
||||
}
|
||||
|
||||
for (const f of foldersToSearch) {
|
||||
const results = await searchInFolder(client, f, query, maxResults, account);
|
||||
const results = await searchInFolder(client, f, query, maxResults, ctx);
|
||||
if (results.length > 0) return results;
|
||||
}
|
||||
|
||||
// Fallback for Gmail: search All Mail if INBOX was empty
|
||||
if (account === 'gmail' && !folder) {
|
||||
const allMailResults = await searchInFolder(client, '[Gmail]/All Mail', query, maxResults, account);
|
||||
if (isGmail(ctx) && !folder) {
|
||||
const allMailResults = await searchInFolder(client, '[Gmail]/All Mail', query, maxResults, ctx);
|
||||
if (allMailResults.length > 0) return allMailResults;
|
||||
}
|
||||
|
||||
@@ -229,11 +224,10 @@ export async function searchMessages(
|
||||
});
|
||||
}
|
||||
|
||||
export async function readMessage(uid: number, account: Account = 'yahoo', folder = 'INBOX'): Promise<FullMessage> {
|
||||
return withClient(account, async (client) => {
|
||||
console.log(`[imap] readMessage uid=${uid} account=${account} folder=${folder}`);
|
||||
export async function readMessage(uid: number, ctx: EmailCtx = 'yahoo', folder = 'INBOX'): Promise<FullMessage> {
|
||||
return withClient(ctx, async (client) => {
|
||||
console.log(`[imap] readMessage uid=${uid} folder=${folder}`);
|
||||
await client.mailboxOpen(folder);
|
||||
console.log(`[imap] mailbox opened, fetching uid=${uid}`);
|
||||
|
||||
let result: FullMessage | null = null;
|
||||
|
||||
@@ -243,7 +237,6 @@ export async function readMessage(uid: number, account: Account = 'yahoo', folde
|
||||
bodyParts: ['TEXT'],
|
||||
}, { uid: true })) {
|
||||
const env = msg.envelope;
|
||||
console.log(`[imap] got msg uid=${msg.uid} subject="${env?.subject}"`);
|
||||
|
||||
const bpKeys = msg.bodyParts ? [...msg.bodyParts.keys()] : [];
|
||||
console.log(`[imap] bodyParts keys:`, JSON.stringify(bpKeys));
|
||||
@@ -252,10 +245,8 @@ export async function readMessage(uid: number, account: Account = 'yahoo', folde
|
||||
msg.bodyParts?.get('text') ??
|
||||
msg.bodyParts?.get('TEXT') ??
|
||||
msg.bodyParts?.get('1');
|
||||
console.log(`[imap] textBuf length=${textBuf ? textBuf.length : 'null'}`);
|
||||
|
||||
const rawBody = textBuf ? textBuf.toString('utf-8') : '';
|
||||
|
||||
const body = rawBody
|
||||
.replace(/<style[^>]*>[\s\S]*?<\/style>/gi, '')
|
||||
.replace(/<script[^>]*>[\s\S]*?<\/script>/gi, '')
|
||||
@@ -264,8 +255,6 @@ export async function readMessage(uid: number, account: Account = 'yahoo', folde
|
||||
.trim()
|
||||
.slice(0, 10000);
|
||||
|
||||
console.log(`[imap] body length after strip=${body.length}`);
|
||||
|
||||
result = {
|
||||
uid: msg.uid,
|
||||
messageId: env?.messageId ?? '',
|
||||
@@ -282,17 +271,17 @@ export async function readMessage(uid: number, account: Account = 'yahoo', folde
|
||||
|
||||
if (!result) throw new Error(`Message UID ${uid} not found`);
|
||||
|
||||
// Mark as seen AFTER the fetch loop fully completes — calling messageFlagsAdd
|
||||
// inside the for-await loop deadlocks because the FETCH command is still active.
|
||||
console.log(`[imap] marking uid=${uid} as seen`);
|
||||
await client.messageFlagsAdd([uid], ['\\Seen'], { uid: true });
|
||||
|
||||
console.log(`[imap] readMessage done uid=${uid}`);
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
export async function getProfile(account: Account = 'yahoo'): Promise<{ email: string; name: string; account: string }> {
|
||||
export async function getProfile(ctx: EmailCtx = 'yahoo'): Promise<{ email: string; name: string; account: string }> {
|
||||
if (typeof ctx === 'object') {
|
||||
return { email: ctx.user, name: ctx.user.split('@')[0], account: 'custom' };
|
||||
}
|
||||
const emailMap: Record<Account, string> = {
|
||||
yahoo: process.env['YAHOO_EMAIL'] ?? '',
|
||||
fetcherpay: process.env['FETCHERPAY_EMAIL'] ?? '',
|
||||
@@ -302,12 +291,12 @@ export async function getProfile(account: Account = 'yahoo'): Promise<{ email: s
|
||||
founder: process.env['FOUNDER_EMAIL'] ?? '',
|
||||
gmail: process.env['GMAIL_EMAIL'] ?? '',
|
||||
};
|
||||
const email = emailMap[account] ?? '';
|
||||
return { email, name: email.split('@')[0], account };
|
||||
const email = emailMap[ctx] ?? '';
|
||||
return { email, name: email.split('@')[0], account: ctx };
|
||||
}
|
||||
|
||||
export async function listFolders(account: Account = 'yahoo'): Promise<string[]> {
|
||||
return withClient(account, async (client) => {
|
||||
export async function listFolders(ctx: EmailCtx = 'yahoo'): Promise<string[]> {
|
||||
return withClient(ctx, async (client) => {
|
||||
const mailboxes = await client.list();
|
||||
return mailboxes.map((m) => m.path);
|
||||
});
|
||||
|
||||
308
src/index.ts
308
src/index.ts
@@ -1,4 +1,5 @@
|
||||
import 'dotenv/config';
|
||||
import crypto from 'crypto';
|
||||
import express from 'express';
|
||||
import cors from 'cors';
|
||||
import cookieParser from 'cookie-parser';
|
||||
@@ -11,7 +12,7 @@ import {
|
||||
ListResourcesRequestSchema,
|
||||
isInitializeRequest,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { tools, handleToolCall } from './tools.js';
|
||||
import { tools, handleToolCall, stripAccountParam } from './tools.js';
|
||||
import { getManifest, getOpenApiSpec, getOpenApiSpecMail, getOpenApiSpecSocial } from './manifest.js';
|
||||
import { routeWhatsAppWebhook, registerWhatsAppNumber, type RoutedWebhookEvent } from './multitenancy/webhook-router.js';
|
||||
import { storeCredential, type Platform } from './multitenancy/credential-store.js';
|
||||
@@ -22,12 +23,18 @@ import {
|
||||
createAuthCode,
|
||||
exchangeCodeForToken,
|
||||
validateAccessToken,
|
||||
getTokenCustomer,
|
||||
getAuthorizeHtml,
|
||||
isValidRedirectUri,
|
||||
ensureOAuthAppRegistered,
|
||||
} from './oauth.js';
|
||||
import { initDatabase, getPool } from './db.js';
|
||||
import { hashPassword, verifyPassword, signJWT, verifyJWT, findCustomerByEmail, createCustomer, setResetToken, findCustomerByResetToken, clearResetToken, updatePassword } from './auth.js';
|
||||
import { recordUsage, getMonthlyUsage, getUsageBreakdown, checkLimit } from './billing/usage.js';
|
||||
import { getCustomerInvoices, getInvoiceByNumber, markInvoiceSent, markInvoicePaid, generateMonthlyInvoice } from './billing/invoices.js';
|
||||
import { getAllPlatformHealth } from './multitenancy/platform-health.js';
|
||||
import { deliverWebhook, isValidWebhookUrl } from './webhooks/delivery.js';
|
||||
import redis from './redis.js';
|
||||
|
||||
const app = express();
|
||||
app.use(cookieParser());
|
||||
@@ -320,9 +327,20 @@ async function requireAuth(req: express.Request, res: express.Response, next: ex
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check OAuth Bearer token
|
||||
// 3. Check OAuth Bearer token — resolve to customer when token has customer_id binding
|
||||
const bearerToken = extractBearerToken(req);
|
||||
if (bearerToken && await validateAccessToken(bearerToken)) return next();
|
||||
if (bearerToken) {
|
||||
const tokenInfo = await getTokenCustomer(bearerToken);
|
||||
if (tokenInfo) {
|
||||
const customer = await resolveCustomerById(tokenInfo.customerId);
|
||||
if (customer && customer.active) {
|
||||
(req as express.Request & { customer?: Customer }).customer = customer;
|
||||
return next();
|
||||
}
|
||||
}
|
||||
// Fall back to legacy tokens (no customer_id) — validate presence only
|
||||
if (await validateAccessToken(bearerToken)) return next();
|
||||
}
|
||||
|
||||
// 4. Check JWT session cookie (web app auth)
|
||||
const jwtCookie = req.cookies?.session;
|
||||
@@ -364,17 +382,19 @@ app.use((req, res, next) => {
|
||||
next();
|
||||
});
|
||||
|
||||
function createMcpServer() {
|
||||
function createMcpServer(customer?: Customer) {
|
||||
const server = new Server(
|
||||
{ name: 'hermes', version: '1.0.0' },
|
||||
{ capabilities: { tools: {}, resources: {} } }
|
||||
);
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools }));
|
||||
const visibleTools = customer ? tools.map(stripAccountParam) : tools;
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => ({ tools: visibleTools }));
|
||||
server.setRequestHandler(ListResourcesRequestSchema, async () => ({ resources: [] }));
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
return handleToolCall(
|
||||
request.params.name,
|
||||
(request.params.arguments ?? {}) as Record<string, unknown>
|
||||
(request.params.arguments ?? {}) as Record<string, unknown>,
|
||||
customer
|
||||
);
|
||||
});
|
||||
return server;
|
||||
@@ -421,6 +441,26 @@ app.get('/oauth/authorize', async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isValidRedirectUri(redirectUri, client.redirect_uris)) {
|
||||
res.status(400).send('redirect_uri not registered for this client');
|
||||
return;
|
||||
}
|
||||
|
||||
// Require authenticated SquareMCP session to show the consent page
|
||||
const jwtCookie = req.cookies?.session;
|
||||
if (!jwtCookie) {
|
||||
const returnTo = encodeURIComponent(req.originalUrl);
|
||||
res.redirect(`/login?return_to=${returnTo}`);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
verifyJWT(jwtCookie);
|
||||
} catch {
|
||||
const returnTo = encodeURIComponent(req.originalUrl);
|
||||
res.redirect(`/login?return_to=${returnTo}`);
|
||||
return;
|
||||
}
|
||||
|
||||
res.setHeader('Content-Type', 'text/html');
|
||||
res.send(getAuthorizeHtml({
|
||||
client_id: clientId,
|
||||
@@ -452,6 +492,11 @@ app.post('/oauth/authorize', async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isValidRedirectUri(redirectUri, client.redirect_uris)) {
|
||||
res.status(400).send('redirect_uri not registered for this client');
|
||||
return;
|
||||
}
|
||||
|
||||
if (action !== 'allow') {
|
||||
const url = new URL(redirectUri);
|
||||
url.searchParams.set('error', 'access_denied');
|
||||
@@ -461,7 +506,19 @@ app.post('/oauth/authorize', async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const code = await createAuthCode(clientId, redirectUri, scope, codeChallenge, codeChallengeMethod);
|
||||
// Bind the auth code to the authenticated customer if present
|
||||
let customerId: string | undefined;
|
||||
const jwtCookie = req.cookies?.session;
|
||||
if (jwtCookie) {
|
||||
try {
|
||||
const payload = verifyJWT(jwtCookie);
|
||||
customerId = payload.sub;
|
||||
} catch {
|
||||
// no binding — legacy flow
|
||||
}
|
||||
}
|
||||
|
||||
const code = await createAuthCode(clientId, redirectUri, scope, codeChallenge, codeChallengeMethod, customerId);
|
||||
const url = new URL(redirectUri);
|
||||
url.searchParams.set('code', code.code);
|
||||
if (state) url.searchParams.set('state', state);
|
||||
@@ -502,6 +559,121 @@ app.post('/oauth/token', async (req, res) => {
|
||||
});
|
||||
});
|
||||
|
||||
// ── DCR browser flow ────────────────────────────────────────────────────────
|
||||
|
||||
// Kick off the "Connect MCP Client" browser flow — server redirects to consent page
|
||||
// so the client_id never needs to be exposed in the frontend JS.
|
||||
app.get('/oauth/connect-mcp', (req, res) => {
|
||||
const clientId = process.env.OAUTH_CLIENT_ID;
|
||||
if (!clientId) {
|
||||
res.status(503).send('MCP OAuth app not configured (OAUTH_CLIENT_ID missing)');
|
||||
return;
|
||||
}
|
||||
const callbackUrl = `${SERVER_URL}/oauth/mcp-callback`;
|
||||
const params = new URLSearchParams({
|
||||
client_id: clientId,
|
||||
redirect_uri: callbackUrl,
|
||||
response_type: 'code',
|
||||
scope: 'mcp',
|
||||
});
|
||||
res.redirect(`/oauth/authorize?${params}`);
|
||||
});
|
||||
|
||||
// Callback — exchange code for token and render the config snippet page
|
||||
app.get('/oauth/mcp-callback', async (req, res) => {
|
||||
const code = req.query.code as string | undefined;
|
||||
const error = req.query.error as string | undefined;
|
||||
|
||||
if (error || !code) {
|
||||
res.status(400).send(renderMcpCallbackHtml({ error: error || 'Missing authorization code' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const clientId = process.env.OAUTH_CLIENT_ID;
|
||||
const clientSecret = process.env.OAUTH_CLIENT_SECRET;
|
||||
if (!clientId || !clientSecret) {
|
||||
res.status(503).send(renderMcpCallbackHtml({ error: 'Server misconfiguration — OAUTH_CLIENT_ID/SECRET missing' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const callbackUrl = `${SERVER_URL}/oauth/mcp-callback`;
|
||||
const token = await exchangeCodeForToken(clientId, clientSecret, code, callbackUrl);
|
||||
if (!token) {
|
||||
res.status(400).send(renderMcpCallbackHtml({ error: 'Token exchange failed — code may be expired or already used' }));
|
||||
return;
|
||||
}
|
||||
|
||||
res.setHeader('Content-Type', 'text/html');
|
||||
res.send(renderMcpCallbackHtml({ token: token.access_token, serverUrl: SERVER_URL }));
|
||||
});
|
||||
|
||||
function renderMcpCallbackHtml(opts: { token?: string; serverUrl?: string; error?: string }): string {
|
||||
if (opts.error) {
|
||||
return `<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><title>Connection failed</title>
|
||||
<style>body{font-family:system-ui,sans-serif;background:#0f0f10;color:#e5e5e5;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0}
|
||||
.card{background:#1a1a1b;border:1px solid #2a2a2b;border-radius:12px;padding:32px;max-width:520px;width:100%}
|
||||
h1{color:#dc2626;margin:0 0 12px}p{color:#888;margin:0}</style></head>
|
||||
<body><div class="card"><h1>Connection failed</h1><p>${opts.error}</p></div></body></html>`;
|
||||
}
|
||||
|
||||
const { token, serverUrl } = opts;
|
||||
const claudeConfig = JSON.stringify({
|
||||
mcpServers: { 'hermes-mcp': { type: 'http', url: `${serverUrl}/mcp`, headers: { Authorization: `Bearer ${token}` } } }
|
||||
}, null, 2);
|
||||
const codexConfig = JSON.stringify({
|
||||
mcpServers: { 'hermes-mcp': { type: 'http', url: `${serverUrl}/mcp`, headers: { Authorization: `Bearer ${token}` } } }
|
||||
}, null, 2);
|
||||
|
||||
const esc = (s: string) => s.replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>');
|
||||
|
||||
return `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>MCP Client Connected — SquareMCP</title>
|
||||
<style>
|
||||
body{font-family:system-ui,sans-serif;background:#0f0f10;color:#e5e5e5;margin:0;padding:24px}
|
||||
.card{background:#1a1a1b;border:1px solid #2a2a2b;border-radius:12px;padding:32px;max-width:680px;margin:0 auto}
|
||||
h1{font-size:22px;margin:0 0 8px;color:#10a37f}
|
||||
.subtitle{color:#888;margin:0 0 28px;font-size:14px}
|
||||
h2{font-size:14px;font-weight:600;color:#888;text-transform:uppercase;letter-spacing:.05em;margin:20px 0 8px}
|
||||
pre{background:#0f0f10;border:1px solid #2a2a2b;border-radius:8px;padding:16px;font-size:12px;overflow-x:auto;position:relative}
|
||||
.copy-btn{position:absolute;top:8px;right:8px;background:#2a2a2b;border:none;color:#888;padding:4px 10px;border-radius:6px;cursor:pointer;font-size:11px}
|
||||
.copy-btn:hover{color:#e5e5e5}
|
||||
.token-box{background:#0f0f10;border:1px solid #2a2a2b;border-radius:8px;padding:12px 16px;font-family:monospace;font-size:13px;word-break:break-all;margin-bottom:8px}
|
||||
.warn{color:#888;font-size:12px;margin:4px 0 20px}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<h1>MCP Client Connected!</h1>
|
||||
<p class="subtitle">Copy your access token and the config for your MCP client below.</p>
|
||||
|
||||
<h2>Your Access Token</h2>
|
||||
<div class="token-box">${esc(token!)}</div>
|
||||
<p class="warn">Store this securely — it won't be shown again.</p>
|
||||
|
||||
<h2>Claude Desktop <code>claude_desktop_config.json</code></h2>
|
||||
<pre id="claude-cfg">${esc(claudeConfig)}<button class="copy-btn" onclick="copy('claude-cfg')">Copy</button></pre>
|
||||
|
||||
<h2>Codex CLI / opencode config</h2>
|
||||
<pre id="codex-cfg">${esc(codexConfig)}<button class="copy-btn" onclick="copy('codex-cfg')">Copy</button></pre>
|
||||
</div>
|
||||
<script>
|
||||
function copy(id) {
|
||||
const pre = document.getElementById(id);
|
||||
const text = pre.innerText.replace(/Copy$/, '').trim();
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
const btn = pre.querySelector('.copy-btn');
|
||||
btn.textContent = 'Copied!';
|
||||
setTimeout(() => { btn.textContent = 'Copy'; }, 1500);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`;
|
||||
}
|
||||
|
||||
// ── TikTok Login Kit + Content Posting auth flow ───────────────────────────
|
||||
app.get('/auth/tiktok/start', async (req, res) => {
|
||||
if (!TIKTOK_CLIENT_KEY) {
|
||||
@@ -650,22 +822,25 @@ app.get('/auth/tiktok/callback', async (req, res) => {
|
||||
|
||||
// ── Streamable HTTP transport (MCP 1.x standard) ────────────────────────────
|
||||
const httpTransports = new Map<string, StreamableHTTPServerTransport>();
|
||||
const sessionCustomers = new Map<string, Customer>();
|
||||
|
||||
async function createSession(): Promise<StreamableHTTPServerTransport> {
|
||||
async function createSession(customer?: Customer): Promise<StreamableHTTPServerTransport> {
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => crypto.randomUUID(),
|
||||
onsessioninitialized: (id) => {
|
||||
console.log(`[mcp] Session initialized: ${id}`);
|
||||
httpTransports.set(id, transport);
|
||||
if (customer) sessionCustomers.set(id, customer);
|
||||
},
|
||||
});
|
||||
transport.onclose = () => {
|
||||
if (transport.sessionId) {
|
||||
console.log(`[mcp] Session closed: ${transport.sessionId}`);
|
||||
httpTransports.delete(transport.sessionId);
|
||||
sessionCustomers.delete(transport.sessionId);
|
||||
}
|
||||
};
|
||||
const server = createMcpServer();
|
||||
const server = createMcpServer(customer);
|
||||
await server.connect(transport);
|
||||
return transport;
|
||||
}
|
||||
@@ -674,6 +849,7 @@ app.post('/mcp', requireAuth, async (req, res) => {
|
||||
const sessionId = req.headers['mcp-session-id'] as string | undefined;
|
||||
console.log(`[mcp] POST sessionId=${sessionId ?? 'none'}, isInit=${isInitializeRequest(req.body)}`);
|
||||
|
||||
const reqCustomer = (req as express.Request & { customer?: Customer }).customer;
|
||||
let transport: StreamableHTTPServerTransport;
|
||||
|
||||
if (sessionId && httpTransports.has(sessionId)) {
|
||||
@@ -684,12 +860,12 @@ app.post('/mcp', requireAuth, async (req, res) => {
|
||||
console.warn(`[mcp] Stale session ${sessionId} re-initializing — pod may have restarted`);
|
||||
}
|
||||
console.log(`[mcp] Creating new session`);
|
||||
transport = await createSession();
|
||||
transport = await createSession(reqCustomer);
|
||||
} else {
|
||||
// Stale session ID from a pod restart — transparently create a new session
|
||||
// and handle the request. Our tools are stateless so no context is lost.
|
||||
console.warn(`[mcp] Unknown session ${sessionId ?? '(none)'} — auto-recovering with new session`);
|
||||
transport = await createSession();
|
||||
transport = await createSession(reqCustomer);
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -725,7 +901,8 @@ app.get('/sse', requireAuth, async (req, res) => {
|
||||
const transport = new SSEServerTransport('/messages', res);
|
||||
sseTransports.set(transport.sessionId, transport);
|
||||
res.on('close', () => sseTransports.delete(transport.sessionId));
|
||||
const server = createMcpServer();
|
||||
const sseCustomer = (req as express.Request & { customer?: Customer }).customer;
|
||||
const server = createMcpServer(sseCustomer);
|
||||
await server.connect(transport);
|
||||
});
|
||||
|
||||
@@ -919,7 +1096,12 @@ app.get('/api/whatsapp/templates', requireAuth, async (req, res) => {
|
||||
// ── WhatsApp webhook (multi-tenant) ─────────────────────────────
|
||||
async function handleInboundWhatsAppMessage(event: RoutedWebhookEvent): Promise<void> {
|
||||
console.log(`[webhook/whatsapp] inbound message from=${event.message.from} customer=${event.customerId} type=${event.message.type}`);
|
||||
// Future: route to customer's agent or queue for processing
|
||||
// Fire-and-forget — don't block the webhook acknowledgement
|
||||
deliverWebhook(event.customerId, 'whatsapp', 'inbound_message', {
|
||||
from: event.message.from,
|
||||
text: (event.message as unknown as Record<string, unknown>).text ?? null,
|
||||
timestamp: event.message.timestamp,
|
||||
}).catch((err) => console.error('[webhook/whatsapp] delivery error:', err));
|
||||
}
|
||||
|
||||
// WhatsApp webhook verification (GET)
|
||||
@@ -935,13 +1117,14 @@ app.get('/webhook/whatsapp', (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
// WhatsApp webhook delivery (POST) — multi-tenant routed
|
||||
app.post('/webhook/whatsapp', express.json(), async (req, res) => {
|
||||
// WhatsApp webhook delivery (POST) — raw body preserved for HMAC verification
|
||||
app.post('/webhook/whatsapp', express.raw({ type: '*/*' }), async (req, res) => {
|
||||
// Always acknowledge immediately to prevent Meta retries (20s window)
|
||||
res.status(200).send('EVENT_RECEIVED');
|
||||
|
||||
try {
|
||||
const events = await routeWhatsAppWebhook(req.body as Record<string, unknown>);
|
||||
const body = JSON.parse((req.body as Buffer).toString('utf8')) as Record<string, unknown>;
|
||||
const events = await routeWhatsAppWebhook(body);
|
||||
for (const event of events) {
|
||||
await handleInboundWhatsAppMessage(event);
|
||||
}
|
||||
@@ -1125,15 +1308,62 @@ app.get('/api/connections', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
const platforms: Platform[] = ['email', 'whatsapp', 'linkedin', 'telegram', 'discord', 'instagram', 'twitter', 'tiktok', 'snapchat', 'facebook', 'obsidian'];
|
||||
|
||||
const status: Record<string, boolean> = {};
|
||||
for (const platform of platforms) {
|
||||
const cred = await customer.getCredential(platform);
|
||||
status[platform] = cred !== null;
|
||||
}
|
||||
const results = await Promise.all(platforms.map((p) => customer.getCredential(p)));
|
||||
const status: Record<string, boolean> = Object.fromEntries(
|
||||
platforms.map((p, i) => [p, results[i] !== null])
|
||||
);
|
||||
|
||||
res.json({ customerId: customer.id, connections: status });
|
||||
});
|
||||
|
||||
app.get('/api/health/platforms', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
const health = await getAllPlatformHealth(customer.id);
|
||||
res.json({ health });
|
||||
});
|
||||
|
||||
// ── Webhooks ────────────────────────────────────────────────────
|
||||
|
||||
app.get('/api/webhooks/config', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
const [rows] = await getPool().query<any[]>(
|
||||
'SELECT webhook_url FROM customers WHERE id = ?',
|
||||
[customer.id]
|
||||
);
|
||||
res.json({ webhookUrl: rows[0]?.webhook_url ?? null });
|
||||
});
|
||||
|
||||
app.post('/api/webhooks/config', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
const webhookUrl = (req.body as Record<string, unknown>).webhook_url as string | undefined;
|
||||
|
||||
if (!webhookUrl) {
|
||||
res.status(400).json({ error: 'webhook_url required' });
|
||||
return;
|
||||
}
|
||||
if (!isValidWebhookUrl(webhookUrl)) {
|
||||
res.status(400).json({ error: 'Invalid webhook URL — must be https:// with a public hostname' });
|
||||
return;
|
||||
}
|
||||
|
||||
const secret = crypto.randomBytes(32).toString('hex');
|
||||
await getPool().query(
|
||||
'UPDATE customers SET webhook_url = ?, webhook_secret = ? WHERE id = ?',
|
||||
[webhookUrl, secret, customer.id]
|
||||
);
|
||||
// Secret returned only at creation/rotation; not retrievable afterward
|
||||
res.json({ webhookUrl, webhookSecret: secret });
|
||||
});
|
||||
|
||||
app.delete('/api/webhooks/config', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
await getPool().query(
|
||||
'UPDATE customers SET webhook_url = NULL, webhook_secret = NULL WHERE id = ?',
|
||||
[customer.id]
|
||||
);
|
||||
res.json({ deleted: true });
|
||||
});
|
||||
|
||||
// ── Usage & Limits ──────────────────────────────────────────────
|
||||
|
||||
app.get('/api/usage', meterMiddleware, async (req, res) => {
|
||||
@@ -1150,6 +1380,20 @@ app.get('/api/usage', meterMiddleware, async (req, res) => {
|
||||
});
|
||||
});
|
||||
|
||||
app.get('/api/usage/daily', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as unknown as { customer: Customer }).customer;
|
||||
const [rows] = await getPool().query<any[]>(
|
||||
`SELECT DATE(created_at) as date, COUNT(*) as count
|
||||
FROM usage_logs
|
||||
WHERE customer_id = ?
|
||||
AND created_at >= DATE_FORMAT(NOW(), '%Y-%m-01')
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC`,
|
||||
[customer.id]
|
||||
);
|
||||
res.json({ daily: rows });
|
||||
});
|
||||
|
||||
// ── Invoices ────────────────────────────────────────────────────
|
||||
|
||||
app.get('/api/invoices', meterMiddleware, async (req, res) => {
|
||||
@@ -1159,8 +1403,9 @@ app.get('/api/invoices', meterMiddleware, async (req, res) => {
|
||||
});
|
||||
|
||||
app.get('/api/invoices/:number', meterMiddleware, async (req, res) => {
|
||||
const customer = (req as express.Request & { customer?: Customer }).customer;
|
||||
const invoice = await getInvoiceByNumber(req.params.number);
|
||||
if (!invoice) {
|
||||
if (!invoice || invoice.customer_id !== customer?.id) {
|
||||
res.status(404).json({ error: 'Invoice not found' });
|
||||
return;
|
||||
}
|
||||
@@ -1297,6 +1542,13 @@ app.post('/api/admin/invoices/:number/pay', requireAdmin, async (req, res) => {
|
||||
res.json({ paid: true });
|
||||
});
|
||||
|
||||
app.get('/api/admin/webhooks/dlq/:customerId', requireAdmin, async (req, res) => {
|
||||
const { customerId } = req.params;
|
||||
const entries = await redis.lRange(`webhook:dlq:${customerId}`, 0, -1);
|
||||
const events = entries.map((e) => JSON.parse(e) as unknown);
|
||||
res.json({ customerId, events, count: events.length });
|
||||
});
|
||||
|
||||
// ── LinkedIn REST endpoints ─────────────────────────────────────
|
||||
app.get('/api/linkedin/profile', requireAuth, async (req, res) => {
|
||||
const account = req.query.account as string | undefined;
|
||||
@@ -1748,6 +2000,18 @@ app.get('/health', (_req, res) => {
|
||||
async function main() {
|
||||
await initDatabase();
|
||||
|
||||
// Ensure the pre-registered SquareMCP OAuth app exists for the browser DCR flow
|
||||
const oauthClientId = process.env.OAUTH_CLIENT_ID;
|
||||
const oauthClientSecret = process.env.OAUTH_CLIENT_SECRET;
|
||||
if (oauthClientId && oauthClientSecret) {
|
||||
await ensureOAuthAppRegistered(oauthClientId, oauthClientSecret, [
|
||||
`${SERVER_URL}/oauth/mcp-callback`,
|
||||
'http://localhost:*',
|
||||
'claude-desktop://callback',
|
||||
'opencode://callback',
|
||||
]);
|
||||
}
|
||||
|
||||
app.listen(PORT, () => {
|
||||
console.log(`Hermes MCP server running on port ${PORT}`);
|
||||
console.log(` Streamable HTTP: ${SERVER_URL}/mcp`);
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import { createClient } from 'redis';
|
||||
|
||||
const redis = createClient({ url: process.env.REDIS_URL });
|
||||
redis.connect().catch((err) => console.error('[audit-log] Redis connect error:', err));
|
||||
import redis from '../redis.js';
|
||||
|
||||
export interface AuditEntry {
|
||||
customerId: string;
|
||||
|
||||
101
src/multitenancy/credential-store.test.ts
Normal file
101
src/multitenancy/credential-store.test.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
const { mockRedisGet, mockRedisSet, mockRedisDel, mockRedisKeys } = vi.hoisted(() => ({
|
||||
mockRedisGet: vi.fn(),
|
||||
mockRedisSet: vi.fn(),
|
||||
mockRedisDel: vi.fn(),
|
||||
mockRedisKeys: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../redis.js', () => ({
|
||||
default: {
|
||||
get: mockRedisGet,
|
||||
set: mockRedisSet,
|
||||
del: mockRedisDel,
|
||||
keys: mockRedisKeys,
|
||||
},
|
||||
}));
|
||||
|
||||
const mockTryRefreshToken = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./token-refresh.js', () => ({ tryRefreshToken: mockTryRefreshToken }));
|
||||
|
||||
// Use a real 32-byte key so AES-256-GCM doesn't throw
|
||||
process.env.CREDENTIAL_ENCRYPTION_KEY = '0'.repeat(64);
|
||||
|
||||
import { getCredential, storeCredential } from './credential-store.js';
|
||||
|
||||
function encryptCreds(creds: object): string {
|
||||
// We can't call encrypt() directly (not exported), so we'll round-trip through storeCredential
|
||||
// For tests, we'll use a helper approach: just test behavior using storeCredential to set up state.
|
||||
return JSON.stringify(creds); // placeholder — see note below
|
||||
}
|
||||
|
||||
describe('getCredential', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockRedisSet.mockResolvedValue('OK');
|
||||
});
|
||||
|
||||
it('returns null when no credential stored', async () => {
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
const result = await getCredential('cust1', 'linkedin');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('returns credential when not expired', async () => {
|
||||
// Store then retrieve using real encryption
|
||||
const creds = { accessToken: 'tok', expiresAt: Date.now() + 3_600_000 };
|
||||
await storeCredential('cust1', 'linkedin', creds);
|
||||
const stored = mockRedisSet.mock.calls[0][1] as string;
|
||||
mockRedisGet.mockResolvedValue(stored);
|
||||
|
||||
const result = await getCredential('cust1', 'linkedin');
|
||||
expect(result).toMatchObject({ accessToken: 'tok' });
|
||||
expect(mockTryRefreshToken).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('attempts refresh when token is within 60s of expiry', async () => {
|
||||
const creds = { accessToken: 'old', refreshToken: 'ref', expiresAt: Date.now() + 30_000 };
|
||||
await storeCredential('cust1', 'linkedin', creds);
|
||||
const stored = mockRedisSet.mock.calls[0][1] as string;
|
||||
mockRedisGet.mockResolvedValue(stored);
|
||||
mockTryRefreshToken.mockResolvedValue({ accessToken: 'new', expiresAt: Date.now() + 3_600_000 });
|
||||
|
||||
const result = await getCredential<{ accessToken: string }>('cust1', 'linkedin');
|
||||
expect(mockTryRefreshToken).toHaveBeenCalledWith('cust1', 'linkedin', expect.objectContaining({ accessToken: 'old' }));
|
||||
expect(result?.accessToken).toBe('new');
|
||||
});
|
||||
|
||||
it('returns null when token expired and no refresh token', async () => {
|
||||
const creds = { accessToken: 'old', expiresAt: Date.now() - 1000 };
|
||||
await storeCredential('cust1', 'linkedin', creds);
|
||||
const stored = mockRedisSet.mock.calls[0][1] as string;
|
||||
mockRedisGet.mockResolvedValue(stored);
|
||||
|
||||
const result = await getCredential('cust1', 'linkedin');
|
||||
expect(result).toBeNull();
|
||||
expect(mockTryRefreshToken).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns null when refresh fails', async () => {
|
||||
const creds = { accessToken: 'old', refreshToken: 'ref', expiresAt: Date.now() - 1000 };
|
||||
await storeCredential('cust1', 'linkedin', creds);
|
||||
const stored = mockRedisSet.mock.calls[0][1] as string;
|
||||
mockRedisGet.mockResolvedValue(stored);
|
||||
mockTryRefreshToken.mockResolvedValue(null);
|
||||
|
||||
const result = await getCredential('cust1', 'linkedin');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('returns non-OAuth credentials without expiry check', async () => {
|
||||
const creds = { host: 'imap.gmail.com', port: 993, user: 'u', password: 'p' };
|
||||
await storeCredential('cust1', 'email', creds);
|
||||
const stored = mockRedisSet.mock.calls[0][1] as string;
|
||||
mockRedisGet.mockResolvedValue(stored);
|
||||
|
||||
const result = await getCredential('cust1', 'email');
|
||||
expect(result).toMatchObject({ host: 'imap.gmail.com' });
|
||||
expect(mockTryRefreshToken).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -1,8 +1,5 @@
|
||||
import { createCipheriv, createDecipheriv, randomBytes } from 'crypto';
|
||||
import { createClient } from 'redis';
|
||||
|
||||
const redis = createClient({ url: process.env.REDIS_URL });
|
||||
redis.connect().catch((err) => console.error('[credential-store] Redis connect error:', err));
|
||||
import redis from '../redis.js';
|
||||
|
||||
const ENCRYPTION_KEY = Buffer.from(process.env.CREDENTIAL_ENCRYPTION_KEY ?? '0'.repeat(64), 'hex');
|
||||
// CREDENTIAL_ENCRYPTION_KEY must be a 64-char hex string (32 bytes)
|
||||
@@ -70,7 +67,21 @@ export async function getCredential<T extends PlatformCredentials>(
|
||||
const key = `creds:${customerId}:${platform}`;
|
||||
const encrypted = await redis.get(key);
|
||||
if (!encrypted) return null;
|
||||
return JSON.parse(decrypt(encrypted)) as T;
|
||||
const creds = JSON.parse(decrypt(encrypted)) as T;
|
||||
|
||||
const oauth = creds as OAuthCredentials;
|
||||
if (typeof oauth.accessToken === 'string' && typeof oauth.expiresAt === 'number') {
|
||||
if (oauth.expiresAt < Date.now() + 60_000) {
|
||||
if (oauth.refreshToken) {
|
||||
const { tryRefreshToken } = await import('./token-refresh.js');
|
||||
const refreshed = await tryRefreshToken(customerId, platform, oauth);
|
||||
if (refreshed) return refreshed as T;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return creds;
|
||||
}
|
||||
|
||||
export async function revokeCredential(customerId: string, platform: Platform): Promise<void> {
|
||||
|
||||
113
src/multitenancy/platform-health.test.ts
Normal file
113
src/multitenancy/platform-health.test.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
const { mockRedisGet, mockRedisSetEx } = vi.hoisted(() => ({
|
||||
mockRedisGet: vi.fn(),
|
||||
mockRedisSetEx: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../redis.js', () => ({
|
||||
default: {
|
||||
get: mockRedisGet,
|
||||
setEx: mockRedisSetEx,
|
||||
},
|
||||
}));
|
||||
|
||||
const mockGetCredential = vi.hoisted(() => vi.fn());
|
||||
vi.mock('./credential-store.js', () => ({ getCredential: mockGetCredential }));
|
||||
|
||||
global.fetch = vi.fn();
|
||||
|
||||
import { getAllPlatformHealth } from './platform-health.js';
|
||||
|
||||
describe('getAllPlatformHealth', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
mockRedisSetEx.mockResolvedValue('OK');
|
||||
});
|
||||
|
||||
it('returns cached status without hitting API', async () => {
|
||||
mockRedisGet.mockImplementation((key: string) =>
|
||||
key.includes('linkedin') ? Promise.resolve('healthy') : Promise.resolve(null)
|
||||
);
|
||||
mockGetCredential.mockResolvedValue(null);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
const linkedin = results.find(r => r.platform === 'linkedin');
|
||||
expect(linkedin?.status).toBe('healthy');
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns disconnected when no credential stored', async () => {
|
||||
mockGetCredential.mockResolvedValue(null);
|
||||
// Second redis.get (raw key check) also returns null
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
results.forEach(r => {
|
||||
expect(r.status).toBe('disconnected');
|
||||
});
|
||||
});
|
||||
|
||||
it('returns healthy when OAuth probe succeeds', async () => {
|
||||
mockGetCredential.mockImplementation((id: string, platform: string) =>
|
||||
platform === 'linkedin' ? Promise.resolve({ accessToken: 'tok' }) : Promise.resolve(null)
|
||||
);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
|
||||
mockRedisGet.mockImplementation((key: string) =>
|
||||
key.startsWith('creds:') ? Promise.resolve(null) : Promise.resolve(null)
|
||||
);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
const linkedin = results.find(r => r.platform === 'linkedin');
|
||||
expect(linkedin?.status).toBe('healthy');
|
||||
});
|
||||
|
||||
it('returns expired when OAuth probe returns non-ok', async () => {
|
||||
mockGetCredential.mockImplementation((_id: string, platform: string) =>
|
||||
platform === 'twitter' ? Promise.resolve({ accessToken: 'tok' }) : Promise.resolve(null)
|
||||
);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: false, status: 401 });
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
const twitter = results.find(r => r.platform === 'twitter');
|
||||
expect(twitter?.status).toBe('expired');
|
||||
});
|
||||
|
||||
it('returns unknown when fetch throws', async () => {
|
||||
mockGetCredential.mockImplementation((_id: string, platform: string) =>
|
||||
platform === 'tiktok' ? Promise.resolve({ accessToken: 'tok' }) : Promise.resolve(null)
|
||||
);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockRejectedValue(new Error('network error'));
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
const tiktok = results.find(r => r.platform === 'tiktok');
|
||||
expect(tiktok?.status).toBe('unknown');
|
||||
});
|
||||
|
||||
it('caches the result in Redis', async () => {
|
||||
mockGetCredential.mockResolvedValue(null);
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
|
||||
await getAllPlatformHealth('cust1');
|
||||
expect(mockRedisSetEx).toHaveBeenCalledWith(
|
||||
expect.stringContaining('health:cust1:'),
|
||||
600,
|
||||
expect.any(String)
|
||||
);
|
||||
});
|
||||
|
||||
it('returns healthy for non-OAuth platforms when credential exists', async () => {
|
||||
mockGetCredential.mockImplementation((_id: string, platform: string) =>
|
||||
platform === 'email' ? Promise.resolve({ host: 'smtp.example.com', port: 587, user: 'u', password: 'p' }) : Promise.resolve(null)
|
||||
);
|
||||
mockRedisGet.mockResolvedValue(null);
|
||||
|
||||
const results = await getAllPlatformHealth('cust1');
|
||||
const email = results.find(r => r.platform === 'email');
|
||||
expect(email?.status).toBe('healthy');
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
72
src/multitenancy/platform-health.ts
Normal file
72
src/multitenancy/platform-health.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import redis from '../redis.js';
|
||||
import { getCredential, type Platform } from './credential-store.js';
|
||||
|
||||
const HEALTH_TTL = 600; // 10 minutes
|
||||
|
||||
type HealthStatus = 'healthy' | 'expired' | 'disconnected' | 'unknown';
|
||||
|
||||
interface PlatformHealth {
|
||||
platform: Platform;
|
||||
status: HealthStatus;
|
||||
}
|
||||
|
||||
const OAUTH_PLATFORMS: Platform[] = ['linkedin', 'twitter', 'tiktok', 'instagram', 'facebook', 'snapchat'];
|
||||
const ALL_PLATFORMS: Platform[] = ['email', 'whatsapp', 'linkedin', 'telegram', 'discord', 'instagram', 'twitter', 'tiktok', 'snapchat', 'facebook', 'obsidian'];
|
||||
|
||||
async function checkPlatformHealth(customerId: string, platform: Platform): Promise<HealthStatus> {
|
||||
const cacheKey = `health:${customerId}:${platform}`;
|
||||
const cached = await redis.get(cacheKey);
|
||||
if (cached) return cached as HealthStatus;
|
||||
|
||||
const cred = await getCredential(customerId, platform);
|
||||
let status: HealthStatus;
|
||||
|
||||
if (!cred) {
|
||||
// getCredential returns null for both "not stored" and "expired with no refresh"
|
||||
// Check if there was a stored (but expired) credential by looking at the raw key
|
||||
const rawKey = `creds:${customerId}:${platform}`;
|
||||
const raw = await redis.get(rawKey);
|
||||
status = raw ? 'expired' : 'disconnected';
|
||||
} else if (OAUTH_PLATFORMS.includes(platform)) {
|
||||
// Credential exists and is not expired — probe the API
|
||||
status = await probeOAuthPlatform(platform, cred as { accessToken: string });
|
||||
} else {
|
||||
status = 'healthy';
|
||||
}
|
||||
|
||||
await redis.setEx(cacheKey, HEALTH_TTL, status);
|
||||
return status;
|
||||
}
|
||||
|
||||
async function probeOAuthPlatform(platform: Platform, cred: { accessToken: string }): Promise<HealthStatus> {
|
||||
const probeUrls: Partial<Record<Platform, string>> = {
|
||||
linkedin: 'https://api.linkedin.com/v2/userinfo',
|
||||
twitter: 'https://api.twitter.com/2/users/me',
|
||||
tiktok: 'https://open.tiktokapis.com/v2/user/info/?fields=open_id',
|
||||
instagram: 'https://graph.instagram.com/me?fields=id',
|
||||
facebook: 'https://graph.facebook.com/me?fields=id',
|
||||
snapchat: 'https://adsapi.snapchat.com/v1/me',
|
||||
};
|
||||
|
||||
const url = probeUrls[platform];
|
||||
if (!url) return 'unknown';
|
||||
|
||||
try {
|
||||
const res = await fetch(url, {
|
||||
headers: { Authorization: `Bearer ${cred.accessToken}` },
|
||||
signal: AbortSignal.timeout(8000),
|
||||
});
|
||||
return res.ok ? 'healthy' : 'expired';
|
||||
} catch {
|
||||
return 'unknown';
|
||||
}
|
||||
}
|
||||
|
||||
export async function getAllPlatformHealth(customerId: string): Promise<PlatformHealth[]> {
|
||||
return Promise.all(
|
||||
ALL_PLATFORMS.map(async (platform) => ({
|
||||
platform,
|
||||
status: await checkPlatformHealth(customerId, platform),
|
||||
}))
|
||||
);
|
||||
}
|
||||
102
src/multitenancy/token-refresh.ts
Normal file
102
src/multitenancy/token-refresh.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import { storeCredential, type OAuthCredentials, type Platform } from './credential-store.js';
|
||||
|
||||
interface TokenResponse {
|
||||
access_token: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
}
|
||||
|
||||
async function postRefresh(url: string, body: URLSearchParams): Promise<TokenResponse | null> {
|
||||
try {
|
||||
const res = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: body.toString(),
|
||||
signal: AbortSignal.timeout(15000),
|
||||
});
|
||||
if (!res.ok) {
|
||||
console.warn(`[token-refresh] HTTP ${res.status} from ${url}`);
|
||||
return null;
|
||||
}
|
||||
return await res.json() as TokenResponse;
|
||||
} catch (err) {
|
||||
console.error(`[token-refresh] request failed:`, err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function tryRefreshToken(
|
||||
customerId: string,
|
||||
platform: Platform,
|
||||
creds: OAuthCredentials
|
||||
): Promise<OAuthCredentials | null> {
|
||||
let data: TokenResponse | null = null;
|
||||
|
||||
if (platform === 'linkedin') {
|
||||
data = await postRefresh('https://www.linkedin.com/oauth/v2/accessToken', new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: creds.refreshToken!,
|
||||
client_id: process.env.LINKEDIN_CLIENT_ID ?? '',
|
||||
client_secret: process.env.LINKEDIN_CLIENT_SECRET ?? '',
|
||||
}));
|
||||
|
||||
} else if (platform === 'twitter') {
|
||||
const credentials = Buffer.from(
|
||||
`${process.env.TWITTER_CLIENT_ID ?? ''}:${process.env.TWITTER_CLIENT_SECRET ?? ''}`
|
||||
).toString('base64');
|
||||
try {
|
||||
const res = await fetch('https://api.twitter.com/2/oauth2/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Authorization': `Basic ${credentials}`,
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: creds.refreshToken!,
|
||||
}).toString(),
|
||||
signal: AbortSignal.timeout(15000),
|
||||
});
|
||||
if (res.ok) data = await res.json() as TokenResponse;
|
||||
else console.warn(`[token-refresh] twitter HTTP ${res.status}`);
|
||||
} catch (err) {
|
||||
console.error(`[token-refresh] twitter error:`, err);
|
||||
}
|
||||
|
||||
} else if (platform === 'tiktok') {
|
||||
data = await postRefresh('https://open.tiktokapis.com/v2/oauth/token/', new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: creds.refreshToken!,
|
||||
client_key: process.env.TIKTOK_CLIENT_KEY ?? '',
|
||||
client_secret: process.env.TIKTOK_CLIENT_SECRET ?? '',
|
||||
}));
|
||||
|
||||
} else if (platform === 'instagram' || platform === 'facebook') {
|
||||
// Facebook long-lived token exchange uses the current access token, not refresh token
|
||||
try {
|
||||
const url = new URL('https://graph.facebook.com/oauth/access_token');
|
||||
url.searchParams.set('grant_type', 'fb_exchange_token');
|
||||
url.searchParams.set('client_id', process.env.FACEBOOK_APP_ID ?? '');
|
||||
url.searchParams.set('client_secret', process.env.FACEBOOK_APP_SECRET ?? '');
|
||||
url.searchParams.set('fb_exchange_token', creds.accessToken);
|
||||
const res = await fetch(url.toString(), { signal: AbortSignal.timeout(15000) });
|
||||
if (res.ok) data = await res.json() as TokenResponse;
|
||||
else console.warn(`[token-refresh] ${platform} HTTP ${res.status}`);
|
||||
} catch (err) {
|
||||
console.error(`[token-refresh] ${platform} error:`, err);
|
||||
}
|
||||
}
|
||||
|
||||
if (!data?.access_token) return null;
|
||||
|
||||
const refreshed: OAuthCredentials = {
|
||||
accessToken: data.access_token,
|
||||
refreshToken: data.refresh_token ?? creds.refreshToken,
|
||||
expiresAt: data.expires_in ? Date.now() + data.expires_in * 1000 : undefined,
|
||||
scope: creds.scope,
|
||||
};
|
||||
|
||||
await storeCredential(customerId, platform, refreshed);
|
||||
console.log(`[token-refresh] ${platform} refreshed for customer ${customerId}`);
|
||||
return refreshed;
|
||||
}
|
||||
@@ -1,9 +1,6 @@
|
||||
import { createClient } from 'redis';
|
||||
import redis from '../redis.js';
|
||||
import { getCredential, WhatsAppCredentials } from './credential-store.js';
|
||||
|
||||
const redis = createClient({ url: process.env.REDIS_URL });
|
||||
redis.connect().catch((err) => console.error('[webhook-router] Redis connect error:', err));
|
||||
|
||||
// Call this at customer onboarding when they connect their WhatsApp Business number
|
||||
export async function registerWhatsAppNumber(
|
||||
customerId: string,
|
||||
|
||||
81
src/oauth-dcr.test.ts
Normal file
81
src/oauth-dcr.test.ts
Normal file
@@ -0,0 +1,81 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import type { ResultSetHeader } from 'mysql2';
|
||||
|
||||
const { mockExecute } = vi.hoisted(() => ({ mockExecute: vi.fn() }));
|
||||
|
||||
vi.mock('./db.js', () => ({
|
||||
getPool: vi.fn(() => ({ execute: mockExecute })),
|
||||
isPoolReady: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
import { ensureOAuthAppRegistered, isValidRedirectUri } from './oauth.js';
|
||||
|
||||
function affected(n: number): [ResultSetHeader, unknown[]] {
|
||||
return [{ affectedRows: n } as ResultSetHeader, []];
|
||||
}
|
||||
|
||||
describe('ensureOAuthAppRegistered', () => {
|
||||
beforeEach(() => vi.clearAllMocks());
|
||||
|
||||
it('executes an INSERT ... ON DUPLICATE KEY UPDATE', async () => {
|
||||
mockExecute.mockResolvedValue(affected(1));
|
||||
|
||||
await ensureOAuthAppRegistered('client-1', 'secret-1', [
|
||||
'https://app.example.com/oauth/mcp-callback',
|
||||
'http://localhost:*',
|
||||
'claude-desktop://callback',
|
||||
]);
|
||||
|
||||
const [sql, params] = mockExecute.mock.calls[0] as [string, unknown[]];
|
||||
expect(sql).toContain('INSERT INTO oauth_clients');
|
||||
expect(sql).toContain('ON DUPLICATE KEY UPDATE');
|
||||
expect(params[0]).toBe('client-1');
|
||||
expect(params[1]).toBe('secret-1');
|
||||
const redirectUris = JSON.parse(params[3] as string);
|
||||
expect(redirectUris).toContain('http://localhost:*');
|
||||
expect(redirectUris).toContain('claude-desktop://callback');
|
||||
});
|
||||
|
||||
it('tolerates DB errors gracefully (does not throw)', async () => {
|
||||
// Should propagate — not silently swallow — so callers can log
|
||||
mockExecute.mockRejectedValue(new Error('DB gone'));
|
||||
await expect(
|
||||
ensureOAuthAppRegistered('c', 's', [])
|
||||
).rejects.toThrow('DB gone');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isValidRedirectUri — DCR client redirect URIs', () => {
|
||||
const registered = [
|
||||
'https://hermes.squaremcp.com/oauth/mcp-callback',
|
||||
'http://localhost:*',
|
||||
'claude-desktop://callback',
|
||||
'opencode://callback',
|
||||
];
|
||||
|
||||
it('allows the SquareMCP mcp-callback URI exactly', () => {
|
||||
expect(isValidRedirectUri('https://hermes.squaremcp.com/oauth/mcp-callback', registered)).toBe(true);
|
||||
});
|
||||
|
||||
it('allows claude-desktop://callback exactly', () => {
|
||||
expect(isValidRedirectUri('claude-desktop://callback', registered)).toBe(true);
|
||||
});
|
||||
|
||||
it('allows opencode://callback exactly', () => {
|
||||
expect(isValidRedirectUri('opencode://callback', registered)).toBe(true);
|
||||
});
|
||||
|
||||
it('allows any localhost port via http://localhost:* wildcard', () => {
|
||||
expect(isValidRedirectUri('http://localhost:3000', registered)).toBe(true);
|
||||
expect(isValidRedirectUri('http://localhost:52481/callback', registered)).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects unregistered URIs', () => {
|
||||
expect(isValidRedirectUri('https://evil.com', registered)).toBe(false);
|
||||
expect(isValidRedirectUri('https://hermes.squaremcp.com/other', registered)).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects https://localhost (SSL loopback) — not in registered list', () => {
|
||||
expect(isValidRedirectUri('https://localhost:3000', registered)).toBe(false);
|
||||
});
|
||||
});
|
||||
192
src/oauth.test.ts
Normal file
192
src/oauth.test.ts
Normal file
@@ -0,0 +1,192 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import type { ResultSetHeader } from 'mysql2';
|
||||
|
||||
// ── DB mock ────────────────────────────────────────────────────────────────
|
||||
const { mockExecute } = vi.hoisted(() => ({ mockExecute: vi.fn() }));
|
||||
|
||||
vi.mock('./db.js', () => ({
|
||||
getPool: vi.fn(() => ({ execute: mockExecute })),
|
||||
isPoolReady: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
import {
|
||||
createAuthCode,
|
||||
exchangeCodeForToken,
|
||||
validateAccessToken,
|
||||
getTokenCustomer,
|
||||
isValidRedirectUri,
|
||||
} from './oauth.js';
|
||||
|
||||
// Helper: build a minimal ResultSetHeader-shaped object
|
||||
function affected(n: number): [ResultSetHeader, unknown[]] {
|
||||
return [{ affectedRows: n } as ResultSetHeader, []];
|
||||
}
|
||||
|
||||
// Helper: build a row-result
|
||||
function rows(data: object[]): [object[], unknown[]] {
|
||||
return [data, []];
|
||||
}
|
||||
|
||||
describe('isValidRedirectUri', () => {
|
||||
it('exact match is valid', () => {
|
||||
expect(isValidRedirectUri('https://app.example.com/callback', ['https://app.example.com/callback'])).toBe(true);
|
||||
});
|
||||
|
||||
it('http://localhost:* wildcard matches any localhost port', () => {
|
||||
expect(isValidRedirectUri('http://localhost:3000', ['http://localhost:*'])).toBe(true);
|
||||
expect(isValidRedirectUri('http://localhost:9876/callback', ['http://localhost:*'])).toBe(true);
|
||||
});
|
||||
|
||||
it('http://localhost:* does not match non-localhost', () => {
|
||||
expect(isValidRedirectUri('http://evil.com', ['http://localhost:*'])).toBe(false);
|
||||
expect(isValidRedirectUri('https://localhost:3000', ['http://localhost:*'])).toBe(false);
|
||||
});
|
||||
|
||||
it('unregistered URI is rejected', () => {
|
||||
expect(isValidRedirectUri('https://attacker.com', ['https://app.example.com/callback'])).toBe(false);
|
||||
});
|
||||
|
||||
it('other wildcards are not supported', () => {
|
||||
expect(isValidRedirectUri('https://any.example.com', ['https://*.example.com'])).toBe(false);
|
||||
});
|
||||
|
||||
it('empty registered list rejects everything', () => {
|
||||
expect(isValidRedirectUri('https://app.example.com/callback', [])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createAuthCode', () => {
|
||||
beforeEach(() => vi.clearAllMocks());
|
||||
|
||||
it('inserts with customer_id when provided', async () => {
|
||||
mockExecute
|
||||
// getClient SELECT (called by nothing here; we call createAuthCode directly)
|
||||
.mockResolvedValueOnce([[], []]); // INSERT returns ResultSetHeader but we ignore it
|
||||
|
||||
// The first execute is the INSERT
|
||||
mockExecute.mockResolvedValue([{ insertId: 0, affectedRows: 1 }, []]);
|
||||
|
||||
const code = await createAuthCode('client1', 'http://localhost:3000', 'read', undefined, undefined, 'cust-42');
|
||||
|
||||
expect(code.customer_id).toBe('cust-42');
|
||||
const [sql, params] = mockExecute.mock.calls[0] as [string, unknown[]];
|
||||
expect(sql).toContain('customer_id');
|
||||
expect(params).toContain('cust-42');
|
||||
});
|
||||
|
||||
it('inserts NULL customer_id when not provided', async () => {
|
||||
mockExecute.mockResolvedValue([{ insertId: 0, affectedRows: 1 }, []]);
|
||||
|
||||
const code = await createAuthCode('client1', 'http://localhost:3000');
|
||||
expect(code.customer_id).toBeUndefined();
|
||||
|
||||
const [, params] = mockExecute.mock.calls[0] as [string, unknown[]];
|
||||
// Last param should be null
|
||||
expect(params[params.length - 1]).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('exchangeCodeForToken — replay protection', () => {
|
||||
beforeEach(() => vi.clearAllMocks());
|
||||
|
||||
const fakeClient = {
|
||||
client_id: 'c1',
|
||||
client_secret: 'secret',
|
||||
client_name: 'Test',
|
||||
redirect_uris: ['http://localhost:3000'],
|
||||
created_at: Date.now(),
|
||||
};
|
||||
const fakeAuthCode = {
|
||||
code: 'abc',
|
||||
client_id: 'c1',
|
||||
redirect_uri: 'http://localhost:3000',
|
||||
scope: 'read',
|
||||
code_challenge: null,
|
||||
code_challenge_method: null,
|
||||
customer_id: 'cust-1',
|
||||
expires_at: new Date(Date.now() + 60_000),
|
||||
used: true,
|
||||
};
|
||||
|
||||
it('returns null when UPDATE affectedRows = 0 (already used)', async () => {
|
||||
// getClient SELECT
|
||||
mockExecute.mockResolvedValueOnce(rows([fakeClient]));
|
||||
// UPDATE oauth_clients last_used
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
// Atomic UPDATE auth code — 0 rows (already used)
|
||||
mockExecute.mockResolvedValueOnce(affected(0));
|
||||
|
||||
const result = await exchangeCodeForToken('c1', 'secret', 'abc', 'http://localhost:3000');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('issues token and threads customer_id on success', async () => {
|
||||
// getClient SELECT
|
||||
mockExecute.mockResolvedValueOnce(rows([fakeClient]));
|
||||
// getClient UPDATE last_used
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
// Atomic UPDATE auth code — 1 row consumed
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
// SELECT auth code data
|
||||
mockExecute.mockResolvedValueOnce(rows([fakeAuthCode]));
|
||||
// INSERT token
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
|
||||
const token = await exchangeCodeForToken('c1', 'secret', 'abc', 'http://localhost:3000');
|
||||
expect(token).not.toBeNull();
|
||||
expect(token!.access_token).toBeTruthy();
|
||||
|
||||
// Verify INSERT includes customer_id = 'cust-1'
|
||||
const insertCall = mockExecute.mock.calls.find((c) => (c[0] as string).includes('INSERT INTO oauth_tokens'));
|
||||
expect(insertCall).toBeDefined();
|
||||
const insertParams = insertCall![1] as unknown[];
|
||||
expect(insertParams).toContain('cust-1');
|
||||
});
|
||||
|
||||
it('returns null when client/redirect mismatch', async () => {
|
||||
mockExecute.mockResolvedValueOnce(rows([fakeClient]));
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
mockExecute.mockResolvedValueOnce(affected(1));
|
||||
// SELECT returns code with different client_id
|
||||
mockExecute.mockResolvedValueOnce(rows([{ ...fakeAuthCode, client_id: 'other' }]));
|
||||
|
||||
const result = await exchangeCodeForToken('c1', 'secret', 'abc', 'http://localhost:3000');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTokenCustomer', () => {
|
||||
beforeEach(() => vi.clearAllMocks());
|
||||
|
||||
it('returns customerId when token has one', async () => {
|
||||
mockExecute.mockResolvedValue(rows([{ customer_id: 'cust-99' }]));
|
||||
const result = await getTokenCustomer('tok-abc');
|
||||
expect(result).toEqual({ customerId: 'cust-99' });
|
||||
});
|
||||
|
||||
it('returns null when token not found', async () => {
|
||||
mockExecute.mockResolvedValue(rows([]));
|
||||
const result = await getTokenCustomer('bad-token');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null when token has no customer_id', async () => {
|
||||
mockExecute.mockResolvedValue(rows([{ customer_id: null }]));
|
||||
const result = await getTokenCustomer('tok-legacy');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAccessToken', () => {
|
||||
beforeEach(() => vi.clearAllMocks());
|
||||
|
||||
it('returns true for valid token', async () => {
|
||||
mockExecute.mockResolvedValue(rows([{ token: 'tok' }]));
|
||||
expect(await validateAccessToken('tok')).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false for expired/missing token', async () => {
|
||||
mockExecute.mockResolvedValue(rows([]));
|
||||
expect(await validateAccessToken('bad')).toBe(false);
|
||||
});
|
||||
});
|
||||
90
src/oauth.ts
90
src/oauth.ts
@@ -21,6 +21,7 @@ interface AuthCode {
|
||||
code_challenge?: string;
|
||||
code_challenge_method?: string;
|
||||
expires_at: number;
|
||||
customer_id?: string;
|
||||
}
|
||||
|
||||
interface Token {
|
||||
@@ -72,6 +73,23 @@ function verifyPkce(codeVerifier: string, storedChallenge: string, method?: stri
|
||||
return false;
|
||||
}
|
||||
|
||||
export async function ensureOAuthAppRegistered(
|
||||
clientId: string,
|
||||
clientSecret: string,
|
||||
redirectUris: string[]
|
||||
): Promise<void> {
|
||||
const pool = getPool();
|
||||
await pool.execute(
|
||||
`INSERT INTO oauth_clients (client_id, client_secret, client_name, redirect_urls, grant_types)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
client_secret = VALUES(client_secret),
|
||||
redirect_urls = VALUES(redirect_urls)`,
|
||||
[clientId, clientSecret, 'SquareMCP App', JSON.stringify(redirectUris), JSON.stringify(['authorization_code'])]
|
||||
);
|
||||
console.log(`[oauth] Pre-registered app ${clientId} ensured`);
|
||||
}
|
||||
|
||||
export async function registerClient(body: {
|
||||
client_name?: string;
|
||||
redirect_uris?: string[];
|
||||
@@ -139,7 +157,8 @@ export async function createAuthCode(
|
||||
redirectUri: string,
|
||||
scope?: string,
|
||||
codeChallenge?: string,
|
||||
codeChallengeMethod?: string
|
||||
codeChallengeMethod?: string,
|
||||
customerId?: string
|
||||
): Promise<AuthCode> {
|
||||
const code: AuthCode = {
|
||||
code: generateAuthCode(),
|
||||
@@ -149,18 +168,27 @@ export async function createAuthCode(
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: codeChallengeMethod,
|
||||
expires_at: Date.now() + AUTH_CODE_EXPIRY_MS,
|
||||
customer_id: customerId,
|
||||
};
|
||||
|
||||
const pool = getPool();
|
||||
await pool.execute(
|
||||
'INSERT INTO oauth_auth_codes (code, client_id, redirect_uri, scope, code_challenge, code_challenge_method, expires_at) VALUES (?, ?, ?, ?, ?, ?, ?)',
|
||||
[code.code, clientId, redirectUri, scope || null, codeChallenge || null, codeChallengeMethod || null, new Date(code.expires_at)]
|
||||
'INSERT INTO oauth_auth_codes (code, client_id, redirect_uri, scope, code_challenge, code_challenge_method, expires_at, customer_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)',
|
||||
[code.code, clientId, redirectUri, scope || null, codeChallenge || null, codeChallengeMethod || null, new Date(code.expires_at), customerId || null]
|
||||
);
|
||||
|
||||
console.log(`[oauth] Created auth code ${code.code.slice(0, 8)}... for client ${clientId}`);
|
||||
return code;
|
||||
}
|
||||
|
||||
export function isValidRedirectUri(uri: string, registeredUris: string[]): boolean {
|
||||
for (const registered of registeredUris) {
|
||||
if (registered === uri) return true;
|
||||
if (registered === 'http://localhost:*' && /^http:\/\/localhost:\d+(\/|$)/.test(uri)) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export async function exchangeCodeForToken(
|
||||
clientId: string,
|
||||
clientSecret: string | undefined,
|
||||
@@ -192,19 +220,27 @@ export async function exchangeCodeForToken(
|
||||
}
|
||||
|
||||
const pool = getPool();
|
||||
const db = await pool.getConnection();
|
||||
try {
|
||||
const [rows] = await db.execute<RowDataPacket[]>(
|
||||
'SELECT * FROM oauth_auth_codes WHERE code = ? AND used = FALSE AND expires_at > NOW()',
|
||||
// Atomic consume: only one concurrent request can win this UPDATE
|
||||
const [updateResult] = await pool.execute<import('mysql2').ResultSetHeader>(
|
||||
'UPDATE oauth_auth_codes SET used = TRUE WHERE code = ? AND used = FALSE AND expires_at > NOW()',
|
||||
[code]
|
||||
);
|
||||
|
||||
if (!Array.isArray(rows) || rows.length === 0) {
|
||||
console.log('[oauth] Auth code not found or expired');
|
||||
if (updateResult.affectedRows === 0) {
|
||||
console.log('[oauth] Auth code not found, expired, or already used');
|
||||
return null;
|
||||
}
|
||||
|
||||
// Fetch the row now that it is consumed
|
||||
const [rows] = await pool.execute<RowDataPacket[]>(
|
||||
'SELECT * FROM oauth_auth_codes WHERE code = ?',
|
||||
[code]
|
||||
);
|
||||
if (!Array.isArray(rows) || rows.length === 0) {
|
||||
return null;
|
||||
}
|
||||
const authCode = rows[0];
|
||||
|
||||
if (authCode.client_id !== clientId || authCode.redirect_uri !== redirectUri) {
|
||||
console.log('[oauth] Auth code client/redirect mismatch');
|
||||
return null;
|
||||
@@ -215,16 +251,12 @@ export async function exchangeCodeForToken(
|
||||
console.log('[oauth] Missing code_verifier for PKCE exchange');
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!verifyPkce(codeVerifier, authCode.code_challenge, authCode.code_challenge_method || undefined)) {
|
||||
console.log('[oauth] PKCE verification failed');
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Mark auth code as used
|
||||
await db.execute('UPDATE oauth_auth_codes SET used = TRUE WHERE code = ?', [code]);
|
||||
|
||||
const token: Token = {
|
||||
access_token: generateAccessToken(),
|
||||
token_type: 'Bearer',
|
||||
@@ -233,9 +265,9 @@ export async function exchangeCodeForToken(
|
||||
expires_at: Date.now() + TOKEN_EXPIRY_MS,
|
||||
};
|
||||
|
||||
await db.execute(
|
||||
'INSERT INTO oauth_tokens (token, client_id, token_type, expires_at) VALUES (?, ?, ?, ?)',
|
||||
[token.access_token, clientId, 'access', new Date(token.expires_at)]
|
||||
await pool.execute(
|
||||
'INSERT INTO oauth_tokens (token, client_id, token_type, expires_at, customer_id) VALUES (?, ?, ?, ?, ?)',
|
||||
[token.access_token, clientId, 'access', new Date(token.expires_at), authCode.customer_id || null]
|
||||
);
|
||||
|
||||
console.log(`[oauth] Issued token ${token.access_token.slice(0, 8)}... for client ${clientId}`);
|
||||
@@ -243,8 +275,6 @@ export async function exchangeCodeForToken(
|
||||
} catch (err) {
|
||||
console.error('[oauth] exchangeCodeForToken error:', err);
|
||||
return null;
|
||||
} finally {
|
||||
db.release();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,21 +282,31 @@ export async function validateAccessToken(tokenValue: string): Promise<boolean>
|
||||
try {
|
||||
const pool = getPool();
|
||||
const [rows] = await pool.execute<RowDataPacket[]>(
|
||||
'SELECT * FROM oauth_tokens WHERE token = ? AND expires_at > NOW()',
|
||||
'SELECT token FROM oauth_tokens WHERE token = ? AND expires_at > NOW()',
|
||||
[tokenValue]
|
||||
);
|
||||
|
||||
if (!Array.isArray(rows) || rows.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return Array.isArray(rows) && rows.length > 0;
|
||||
} catch (err) {
|
||||
console.error('[oauth] validateAccessToken error:', err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getTokenCustomer(tokenValue: string): Promise<{ customerId: string } | null> {
|
||||
try {
|
||||
const pool = getPool();
|
||||
const [rows] = await pool.execute<RowDataPacket[]>(
|
||||
'SELECT customer_id FROM oauth_tokens WHERE token = ? AND expires_at > NOW()',
|
||||
[tokenValue]
|
||||
);
|
||||
if (!Array.isArray(rows) || rows.length === 0 || !rows[0].customer_id) return null;
|
||||
return { customerId: rows[0].customer_id as string };
|
||||
} catch (err) {
|
||||
console.error('[oauth] getTokenCustomer error:', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function getAuthorizeHtml(params: {
|
||||
client_id: string;
|
||||
redirect_uri: string;
|
||||
|
||||
6
src/redis.ts
Normal file
6
src/redis.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { createClient } from 'redis';
|
||||
|
||||
const redis = createClient({ url: process.env.REDIS_URL });
|
||||
redis.connect().catch((err) => console.error('[redis] connect error:', err));
|
||||
|
||||
export default redis;
|
||||
107
src/smtp.ts
107
src/smtp.ts
@@ -1,5 +1,6 @@
|
||||
import nodemailer from 'nodemailer';
|
||||
import type { Account } from './imap.js';
|
||||
import type { Account, EmailCtx } from './imap.js';
|
||||
import type { EmailCredentials } from './multitenancy/credential-store.js';
|
||||
|
||||
const FETCHERPAY_SMTP_HOST = process.env['FETCHERPAY_SMTP_HOST'] ?? 'mail.fetcherpay.com';
|
||||
const FETCHERPAY_SMTP_PORT = parseInt(process.env['FETCHERPAY_SMTP_PORT'] ?? '30587');
|
||||
@@ -8,13 +9,13 @@ function fetcherpaySmtpTransport(user: string, pass: string) {
|
||||
return nodemailer.createTransport({
|
||||
host: FETCHERPAY_SMTP_HOST,
|
||||
port: FETCHERPAY_SMTP_PORT,
|
||||
secure: false, // STARTTLS
|
||||
secure: false,
|
||||
auth: { user, pass },
|
||||
tls: { rejectUnauthorized: false },
|
||||
});
|
||||
}
|
||||
|
||||
function getSmtpTransport(account: Account = 'yahoo') {
|
||||
function getEnvSmtpTransport(account: Account = 'yahoo') {
|
||||
switch (account) {
|
||||
case 'fetcherpay':
|
||||
return fetcherpaySmtpTransport(process.env['FETCHERPAY_EMAIL']!, process.env['FETCHERPAY_PASSWORD']!);
|
||||
@@ -31,25 +32,33 @@ function getSmtpTransport(account: Account = 'yahoo') {
|
||||
host: 'smtp.gmail.com',
|
||||
port: 587,
|
||||
secure: false,
|
||||
auth: {
|
||||
user: process.env['GMAIL_EMAIL']!,
|
||||
pass: process.env['GMAIL_APP_PASSWORD']!,
|
||||
},
|
||||
auth: { user: process.env['GMAIL_EMAIL']!, pass: process.env['GMAIL_APP_PASSWORD']! },
|
||||
});
|
||||
default:
|
||||
return nodemailer.createTransport({
|
||||
host: 'smtp.mail.yahoo.com',
|
||||
port: 587,
|
||||
secure: false,
|
||||
auth: {
|
||||
user: process.env['YAHOO_EMAIL']!,
|
||||
pass: process.env['YAHOO_APP_PASSWORD']!,
|
||||
},
|
||||
auth: { user: process.env['YAHOO_EMAIL']!, pass: process.env['YAHOO_APP_PASSWORD']! },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function getSenderEmail(account: Account = 'yahoo'): string {
|
||||
function resolveSmtpTransport(ctx: EmailCtx) {
|
||||
if (typeof ctx === 'object') {
|
||||
return nodemailer.createTransport({
|
||||
host: ctx.smtpHost ?? ctx.host,
|
||||
port: ctx.smtpPort ?? 587,
|
||||
secure: false,
|
||||
auth: { user: ctx.user, pass: ctx.password },
|
||||
tls: { rejectUnauthorized: false },
|
||||
});
|
||||
}
|
||||
return getEnvSmtpTransport(ctx);
|
||||
}
|
||||
|
||||
function resolveSenderEmail(ctx: EmailCtx): string {
|
||||
if (typeof ctx === 'object') return ctx.user;
|
||||
const emailMap: Record<Account, string> = {
|
||||
yahoo: process.env['YAHOO_EMAIL'] ?? '',
|
||||
fetcherpay: process.env['FETCHERPAY_EMAIL'] ?? '',
|
||||
@@ -59,18 +68,18 @@ function getSenderEmail(account: Account = 'yahoo'): string {
|
||||
founder: process.env['FOUNDER_EMAIL'] ?? '',
|
||||
gmail: process.env['GMAIL_EMAIL'] ?? '',
|
||||
};
|
||||
return emailMap[account] ?? '';
|
||||
return emailMap[ctx] ?? '';
|
||||
}
|
||||
|
||||
export async function sendEmail(
|
||||
to: string,
|
||||
subject: string,
|
||||
body: string,
|
||||
account: Account = 'yahoo',
|
||||
ctx: EmailCtx = 'yahoo',
|
||||
): Promise<string> {
|
||||
const transporter = getSmtpTransport(account);
|
||||
const transporter = resolveSmtpTransport(ctx);
|
||||
const info = await transporter.sendMail({
|
||||
from: getSenderEmail(account),
|
||||
from: resolveSenderEmail(ctx),
|
||||
to,
|
||||
subject,
|
||||
text: body,
|
||||
@@ -82,52 +91,52 @@ export async function createDraft(
|
||||
to: string,
|
||||
subject: string,
|
||||
body: string,
|
||||
account: Account = 'yahoo',
|
||||
ctx: EmailCtx = 'yahoo',
|
||||
): Promise<string> {
|
||||
const { ImapFlow } = await import('imapflow');
|
||||
|
||||
const fetcherpayImapBase = {
|
||||
host: process.env['FETCHERPAY_IMAP_HOST'] ?? 'mail.fetcherpay.com',
|
||||
port: parseInt(process.env['FETCHERPAY_IMAP_PORT'] ?? '30993'),
|
||||
secure: true,
|
||||
tls: { rejectUnauthorized: false },
|
||||
};
|
||||
const fetcherpayImapAccounts: Partial<Record<Account, { user: string; pass: string }>> = {
|
||||
fetcherpay: { user: process.env['FETCHERPAY_EMAIL']!, pass: process.env['FETCHERPAY_PASSWORD']! },
|
||||
garfield: { user: process.env['GARFIELD_EMAIL']!, pass: process.env['GARFIELD_PASSWORD']! },
|
||||
sales: { user: process.env['SALES_EMAIL']!, pass: process.env['SALES_PASSWORD']! },
|
||||
leads: { user: process.env['LEADS_EMAIL']!, pass: process.env['LEADS_PASSWORD']! },
|
||||
founder: { user: process.env['FOUNDER_EMAIL']!, pass: process.env['FOUNDER_PASSWORD']! },
|
||||
};
|
||||
let imapConfig;
|
||||
if (fetcherpayImapAccounts[account]) {
|
||||
imapConfig = { ...fetcherpayImapBase, auth: fetcherpayImapAccounts[account]! };
|
||||
} else if (account === 'gmail') {
|
||||
let imapConfig: any;
|
||||
if (typeof ctx === 'object') {
|
||||
imapConfig = {
|
||||
host: 'imap.gmail.com',
|
||||
port: 993,
|
||||
secure: true,
|
||||
auth: {
|
||||
user: process.env['GMAIL_EMAIL']!,
|
||||
pass: process.env['GMAIL_APP_PASSWORD']!,
|
||||
},
|
||||
host: ctx.host,
|
||||
port: ctx.port,
|
||||
secure: ctx.port === 993,
|
||||
auth: { user: ctx.user, pass: ctx.password },
|
||||
tls: { rejectUnauthorized: false },
|
||||
};
|
||||
} else {
|
||||
imapConfig = {
|
||||
host: 'imap.mail.yahoo.com',
|
||||
port: 993,
|
||||
const fetcherpayImapBase = {
|
||||
host: process.env['FETCHERPAY_IMAP_HOST'] ?? 'mail.fetcherpay.com',
|
||||
port: parseInt(process.env['FETCHERPAY_IMAP_PORT'] ?? '30993'),
|
||||
secure: true,
|
||||
auth: {
|
||||
user: process.env['YAHOO_EMAIL']!,
|
||||
pass: process.env['YAHOO_APP_PASSWORD']!,
|
||||
},
|
||||
tls: { rejectUnauthorized: false },
|
||||
};
|
||||
const fetcherpayAccounts: Partial<Record<Account, { user: string; pass: string }>> = {
|
||||
fetcherpay: { user: process.env['FETCHERPAY_EMAIL']!, pass: process.env['FETCHERPAY_PASSWORD']! },
|
||||
garfield: { user: process.env['GARFIELD_EMAIL']!, pass: process.env['GARFIELD_PASSWORD']! },
|
||||
sales: { user: process.env['SALES_EMAIL']!, pass: process.env['SALES_PASSWORD']! },
|
||||
leads: { user: process.env['LEADS_EMAIL']!, pass: process.env['LEADS_PASSWORD']! },
|
||||
founder: { user: process.env['FOUNDER_EMAIL']!, pass: process.env['FOUNDER_PASSWORD']! },
|
||||
};
|
||||
if (fetcherpayAccounts[ctx]) {
|
||||
imapConfig = { ...fetcherpayImapBase, auth: fetcherpayAccounts[ctx]! };
|
||||
} else if (ctx === 'gmail') {
|
||||
imapConfig = {
|
||||
host: 'imap.gmail.com', port: 993, secure: true,
|
||||
auth: { user: process.env['GMAIL_EMAIL']!, pass: process.env['GMAIL_APP_PASSWORD']! },
|
||||
};
|
||||
} else {
|
||||
imapConfig = {
|
||||
host: 'imap.mail.yahoo.com', port: 993, secure: true,
|
||||
auth: { user: process.env['YAHOO_EMAIL']!, pass: process.env['YAHOO_APP_PASSWORD']! },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const client = new ImapFlow(imapConfig);
|
||||
await client.connect();
|
||||
|
||||
const from = getSenderEmail(account);
|
||||
const from = resolveSenderEmail(ctx);
|
||||
const rawMessage = [
|
||||
`From: ${from}`,
|
||||
`To: ${to}`,
|
||||
|
||||
205
src/tools.test.ts
Normal file
205
src/tools.test.ts
Normal file
@@ -0,0 +1,205 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
// Mock all I/O dependencies before importing handleToolCall
|
||||
vi.mock('./billing/usage.js', () => ({
|
||||
recordUsage: vi.fn().mockResolvedValue(undefined),
|
||||
checkLimit: vi.fn().mockResolvedValue({ allowed: true, limit: 1000, used: 5 }),
|
||||
}));
|
||||
vi.mock('./db.js', () => ({ getPool: vi.fn() }));
|
||||
vi.mock('./imap.js', () => ({
|
||||
searchMessages: vi.fn().mockResolvedValue([]),
|
||||
readMessage: vi.fn().mockResolvedValue({}),
|
||||
getProfile: vi.fn().mockResolvedValue({ email: 'test@example.com' }),
|
||||
listFolders: vi.fn().mockResolvedValue([]),
|
||||
}));
|
||||
vi.mock('./smtp.js', () => ({
|
||||
sendEmail: vi.fn().mockResolvedValue({ messageId: 'abc' }),
|
||||
createDraft: vi.fn().mockResolvedValue({ id: '1' }),
|
||||
}));
|
||||
vi.mock('./clients/obsidian.js', () => ({
|
||||
searchNotes: vi.fn().mockResolvedValue([]),
|
||||
getNote: vi.fn().mockResolvedValue({ content: '' }),
|
||||
appendToNote: vi.fn().mockResolvedValue({}),
|
||||
updateNote: vi.fn().mockResolvedValue({}),
|
||||
getSyncStatus: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/whatsapp.js', () => ({
|
||||
sendMessage: vi.fn().mockResolvedValue({}),
|
||||
sendTemplate: vi.fn().mockResolvedValue({}),
|
||||
getMessageStatus: vi.fn().mockResolvedValue({}),
|
||||
listTemplates: vi.fn().mockResolvedValue([]),
|
||||
}));
|
||||
vi.mock('./clients/linkedin.js', () => ({
|
||||
getProfile: vi.fn().mockResolvedValue({}),
|
||||
createPost: vi.fn().mockResolvedValue({}),
|
||||
createVideoPost: vi.fn().mockResolvedValue({}),
|
||||
searchConnections: vi.fn().mockResolvedValue([]),
|
||||
sendMessage: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/telegram.js', () => ({
|
||||
getMe: vi.fn().mockResolvedValue({}),
|
||||
sendMessage: vi.fn().mockResolvedValue({}),
|
||||
sendPhoto: vi.fn().mockResolvedValue({}),
|
||||
getUpdates: vi.fn().mockResolvedValue([]),
|
||||
getChat: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/discord.js', () => ({
|
||||
getMe: vi.fn().mockResolvedValue({}),
|
||||
getGuilds: vi.fn().mockResolvedValue([]),
|
||||
getChannels: vi.fn().mockResolvedValue([]),
|
||||
sendMessage: vi.fn().mockResolvedValue({}),
|
||||
getMessages: vi.fn().mockResolvedValue([]),
|
||||
}));
|
||||
vi.mock('./clients/instagram.js', () => ({
|
||||
getProfile: vi.fn().mockResolvedValue({}),
|
||||
getMedia: vi.fn().mockResolvedValue([]),
|
||||
createImagePost: vi.fn().mockResolvedValue({}),
|
||||
createReel: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/twitter.js', () => ({
|
||||
searchTweets: vi.fn().mockResolvedValue([]),
|
||||
getUserProfile: vi.fn().mockResolvedValue({}),
|
||||
getUserTweets: vi.fn().mockResolvedValue([]),
|
||||
createTweet: vi.fn().mockResolvedValue({}),
|
||||
uploadVideoAndTweet: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/tiktok.js', () => ({
|
||||
getUserProfile: vi.fn().mockResolvedValue({}),
|
||||
getCreatorInfo: vi.fn().mockResolvedValue({}),
|
||||
createVideo: vi.fn().mockResolvedValue({}),
|
||||
getVideoStatus: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./clients/snapchat.js', () => ({
|
||||
getMe: vi.fn().mockResolvedValue({}),
|
||||
createSnap: vi.fn().mockResolvedValue({}),
|
||||
getAdAccounts: vi.fn().mockResolvedValue([]),
|
||||
}));
|
||||
vi.mock('./clients/facebook.js', () => ({
|
||||
getPage: vi.fn().mockResolvedValue({}),
|
||||
getPosts: vi.fn().mockResolvedValue([]),
|
||||
createPost: vi.fn().mockResolvedValue({}),
|
||||
createPhotoPost: vi.fn().mockResolvedValue({}),
|
||||
createVideoPost: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
vi.mock('./redis.js', () => ({ default: { get: vi.fn(), set: vi.fn(), del: vi.fn() } }));
|
||||
vi.mock('./multitenancy/credential-store.js', () => ({
|
||||
getCredential: vi.fn().mockResolvedValue(null),
|
||||
}));
|
||||
|
||||
import { handleToolCall, stripAccountParam, tools } from './tools.js';
|
||||
import { checkLimit, recordUsage } from './billing/usage.js';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.mocked(checkLimit).mockResolvedValue({ allowed: true, limit: 1000, used: 5 });
|
||||
});
|
||||
|
||||
const mockCustomer = {
|
||||
id: 'cust-123',
|
||||
plan: 'growth' as const,
|
||||
active: true,
|
||||
email: 'test@example.com',
|
||||
getCredential: vi.fn().mockResolvedValue(null),
|
||||
};
|
||||
|
||||
describe('handleToolCall — plan limit gate', () => {
|
||||
|
||||
it('returns isError when customer is over limit', async () => {
|
||||
vi.mocked(checkLimit).mockResolvedValue({ allowed: false, limit: 1000, used: 1000 });
|
||||
const result = await handleToolCall('get_profile', {}, mockCustomer);
|
||||
expect(result.isError).toBe(true);
|
||||
expect(result.content[0].text).toMatch(/limit/i);
|
||||
});
|
||||
|
||||
it('does not check limit when no customer (unauthenticated dev mode)', async () => {
|
||||
await handleToolCall('get_profile', {});
|
||||
expect(checkLimit).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('proceeds normally when under limit', async () => {
|
||||
const result = await handleToolCall('get_profile', {}, mockCustomer);
|
||||
expect(result.isError).toBeUndefined();
|
||||
expect(checkLimit).toHaveBeenCalledWith('cust-123', 'growth');
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleToolCall — platform attribution', () => {
|
||||
|
||||
it.each([
|
||||
['send_email', 'email'],
|
||||
['create_draft', 'email'],
|
||||
['search_messages', 'email'],
|
||||
['get_profile', 'email'],
|
||||
['list_folders', 'email'],
|
||||
['yahoo_send_email', 'email'],
|
||||
['linkedin_create_post', 'linkedin'],
|
||||
['obsidian_search_notes', 'obsidian'],
|
||||
['tiktok_create_video', 'tiktok'],
|
||||
['whatsapp_send_message', 'whatsapp'],
|
||||
['telegram_send_message', 'telegram'],
|
||||
['discord_send_message', 'discord'],
|
||||
['instagram_create_post', 'instagram'],
|
||||
['twitter_create_tweet', 'twitter'],
|
||||
['snapchat_create_snap', 'snapchat'],
|
||||
['facebook_create_post', 'facebook'],
|
||||
])('%s → platform "%s"', async (toolName, expectedPlatform) => {
|
||||
await handleToolCall(toolName, {}, mockCustomer).catch(() => {});
|
||||
const calls = vi.mocked(recordUsage).mock.calls;
|
||||
const last = calls[calls.length - 1];
|
||||
if (last) {
|
||||
expect(last[1]).toBe(expectedPlatform);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleToolCall — error handling', () => {
|
||||
it('returns isError: true on tool exception', async () => {
|
||||
const { getProfile } = await import('./imap.js');
|
||||
vi.mocked(getProfile).mockRejectedValueOnce(new Error('IMAP connection refused'));
|
||||
const result = await handleToolCall('get_profile', {});
|
||||
expect(result.isError).toBe(true);
|
||||
expect(result.content[0].text).toContain('IMAP connection refused');
|
||||
});
|
||||
});
|
||||
|
||||
describe('stripAccountParam', () => {
|
||||
it('removes account from inputSchema.properties', () => {
|
||||
const tool = tools.find(t => 'account' in (t.inputSchema.properties ?? {}))!;
|
||||
expect(tool).toBeDefined();
|
||||
const stripped = stripAccountParam(tool);
|
||||
expect(stripped.inputSchema.properties).not.toHaveProperty('account');
|
||||
});
|
||||
|
||||
it('does not mutate the original tool', () => {
|
||||
const tool = tools.find(t => 'account' in (t.inputSchema.properties ?? {}))!;
|
||||
const before = Object.keys(tool.inputSchema.properties ?? {});
|
||||
stripAccountParam(tool);
|
||||
const after = Object.keys(tool.inputSchema.properties ?? {});
|
||||
expect(after).toEqual(before);
|
||||
});
|
||||
|
||||
it('preserves all other properties', () => {
|
||||
const tool = tools.find(t => 'account' in (t.inputSchema.properties ?? {}))!;
|
||||
const stripped = stripAccountParam(tool);
|
||||
const originalWithoutAccount = Object.fromEntries(
|
||||
Object.entries(tool.inputSchema.properties ?? {}).filter(([k]) => k !== 'account')
|
||||
);
|
||||
expect(stripped.inputSchema.properties).toEqual(originalWithoutAccount);
|
||||
expect(stripped.name).toBe(tool.name);
|
||||
expect(stripped.description).toBe(tool.description);
|
||||
});
|
||||
|
||||
it('handles tools with no account param safely', () => {
|
||||
const tool = tools.find(t => !('account' in (t.inputSchema.properties ?? {})))!;
|
||||
expect(tool).toBeDefined();
|
||||
const stripped = stripAccountParam(tool);
|
||||
expect(stripped.inputSchema.properties).toEqual(tool.inputSchema.properties);
|
||||
});
|
||||
|
||||
it('all tools in multi-tenant mode have no account param', () => {
|
||||
const stripped = tools.map(stripAccountParam);
|
||||
for (const t of stripped) {
|
||||
expect(t.inputSchema.properties).not.toHaveProperty('account');
|
||||
}
|
||||
});
|
||||
});
|
||||
88
src/tools.ts
88
src/tools.ts
@@ -1,8 +1,9 @@
|
||||
import { Tool } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { Customer } from './billing/middleware.js';
|
||||
import { recordUsage } from './billing/usage.js';
|
||||
import { searchMessages, readMessage, getProfile, listFolders, type Account } from './imap.js';
|
||||
import { recordUsage, checkLimit } from './billing/usage.js';
|
||||
import { searchMessages, readMessage, getProfile, listFolders, type Account, type EmailCtx } from './imap.js';
|
||||
import { sendEmail, createDraft } from './smtp.js';
|
||||
import type { EmailCredentials } from './multitenancy/credential-store.js';
|
||||
import { searchNotes, getNote, appendToNote, updateNote, getSyncStatus } from './clients/obsidian.js';
|
||||
import { sendMessage, sendTemplate, getMessageStatus, listTemplates } from './clients/whatsapp.js';
|
||||
import { getProfile as getLinkedInProfile, createPost as createLinkedInPost, createVideoPost as createLinkedInVideoPost, searchConnections, sendMessage as sendLinkedInMessage } from './clients/linkedin.js';
|
||||
@@ -727,40 +728,81 @@ function acct(args: Record<string, unknown>): Account {
|
||||
return (args.account as Account) ?? 'yahoo';
|
||||
}
|
||||
|
||||
async function resolveEmailCtx(args: Record<string, unknown>, customer?: Customer): Promise<EmailCtx> {
|
||||
if (customer) {
|
||||
const creds = await customer.getCredential<EmailCredentials>('email');
|
||||
if (creds) return creds;
|
||||
}
|
||||
return acct(args);
|
||||
}
|
||||
|
||||
const PLATFORM_PREFIXES = [
|
||||
'linkedin', 'obsidian', 'whatsapp', 'telegram', 'discord',
|
||||
'instagram', 'twitter', 'tiktok', 'snapchat', 'facebook',
|
||||
];
|
||||
|
||||
function toolPlatform(name: string): string {
|
||||
const prefix = name.split('_')[0];
|
||||
return PLATFORM_PREFIXES.includes(prefix) ? prefix : 'email';
|
||||
}
|
||||
|
||||
export async function handleToolCall(
|
||||
name: string,
|
||||
args: Record<string, unknown>,
|
||||
customer?: Customer
|
||||
): Promise<{ content: Array<{ type: string; text: string }> }> {
|
||||
): Promise<{ content: Array<{ type: string; text: string }>; isError?: boolean }> {
|
||||
console.log(`[tool] ${name}`, JSON.stringify(args));
|
||||
const t0 = Date.now();
|
||||
|
||||
if (customer) {
|
||||
const { allowed } = await checkLimit(customer.id, customer.plan);
|
||||
if (!allowed) {
|
||||
return {
|
||||
content: [{ type: 'text', text: 'Monthly tool call limit reached. Please upgrade your plan.' }],
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
let result: unknown;
|
||||
|
||||
switch (name) {
|
||||
case 'get_profile':
|
||||
result = await getProfile(acct(args));
|
||||
case 'get_profile': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await getProfile(emailCtx);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'search_messages':
|
||||
result = await searchMessages(args.q as string, (args.maxResults as number) ?? 20, acct(args), args.folder as string | undefined);
|
||||
case 'search_messages': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await searchMessages(args.q as string, (args.maxResults as number) ?? 20, emailCtx, args.folder as string | undefined);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'read_message':
|
||||
result = await readMessage(args.uid as number, acct(args), args.folder as string | undefined);
|
||||
case 'read_message': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await readMessage(args.uid as number, emailCtx, args.folder as string | undefined);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'list_folders':
|
||||
result = await listFolders(acct(args));
|
||||
case 'list_folders': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await listFolders(emailCtx);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'create_draft':
|
||||
result = await createDraft(args.to as string, args.subject as string, args.body as string, acct(args));
|
||||
case 'create_draft': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await createDraft(args.to as string, args.subject as string, args.body as string, emailCtx);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'send_email':
|
||||
result = await sendEmail(args.to as string, args.subject as string, args.body as string, acct(args));
|
||||
case 'send_email': {
|
||||
const emailCtx = await resolveEmailCtx(args, customer);
|
||||
result = await sendEmail(args.to as string, args.subject as string, args.body as string, emailCtx);
|
||||
break;
|
||||
}
|
||||
|
||||
// ── Obsidian ──────────────────────────────────────────────────────────
|
||||
case 'obsidian_search_notes':
|
||||
@@ -1126,8 +1168,7 @@ export async function handleToolCall(
|
||||
|
||||
console.log(`[tool] ${name} OK (${Date.now() - t0}ms)`);
|
||||
if (customer) {
|
||||
const platform = name.split('_')[0];
|
||||
recordUsage(customer.id, platform, name).catch(() => {});
|
||||
recordUsage(customer.id, toolPlatform(name), name).catch(() => {});
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text', text: JSON.stringify(result, null, 2) }],
|
||||
@@ -1139,6 +1180,19 @@ export async function handleToolCall(
|
||||
console.error(stack);
|
||||
return {
|
||||
content: [{ type: 'text', text: `Error: ${msg}` }],
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function stripAccountParam(tool: Tool): Tool {
|
||||
const props = tool.inputSchema.properties ?? {};
|
||||
const filtered = Object.fromEntries(Object.entries(props).filter(([k]) => k !== 'account'));
|
||||
return {
|
||||
...tool,
|
||||
inputSchema: {
|
||||
...tool.inputSchema,
|
||||
properties: filtered as Tool['inputSchema']['properties'],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
161
src/webhooks/delivery.test.ts
Normal file
161
src/webhooks/delivery.test.ts
Normal file
@@ -0,0 +1,161 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
const { mockQuery, mockRPush, mockExpire } = vi.hoisted(() => ({
|
||||
mockQuery: vi.fn(),
|
||||
mockRPush: vi.fn(),
|
||||
mockExpire: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../db.js', () => ({ getPool: vi.fn(() => ({ query: mockQuery })) }));
|
||||
vi.mock('../redis.js', () => ({
|
||||
default: { rPush: mockRPush, expire: mockExpire },
|
||||
}));
|
||||
|
||||
global.fetch = vi.fn();
|
||||
|
||||
import { deliverWebhook, isValidWebhookUrl } from './delivery.js';
|
||||
|
||||
// ── URL validation ──────────────────────────────────────────────
|
||||
|
||||
describe('isValidWebhookUrl', () => {
|
||||
it('accepts https:// with public hostname', () => {
|
||||
expect(isValidWebhookUrl('https://my-server.example.com/hook')).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects ftp:// and other non-http schemes', () => {
|
||||
expect(isValidWebhookUrl('ftp://example.com/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects 127.x loopback', () => {
|
||||
expect(isValidWebhookUrl('http://127.0.0.1/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects 10.x private range', () => {
|
||||
expect(isValidWebhookUrl('https://10.0.0.1/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects 192.168.x private range', () => {
|
||||
expect(isValidWebhookUrl('https://192.168.1.1/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects 172.16-31.x private range', () => {
|
||||
expect(isValidWebhookUrl('https://172.16.0.1/hook')).toBe(false);
|
||||
expect(isValidWebhookUrl('https://172.31.255.255/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects localhost hostname', () => {
|
||||
expect(isValidWebhookUrl('http://localhost/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects .local domains', () => {
|
||||
expect(isValidWebhookUrl('http://myserver.local/hook')).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects invalid URLs', () => {
|
||||
expect(isValidWebhookUrl('not-a-url')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
// ── deliverWebhook ──────────────────────────────────────────────
|
||||
|
||||
describe('deliverWebhook', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useFakeTimers();
|
||||
mockRPush.mockResolvedValue(1);
|
||||
mockExpire.mockResolvedValue(1);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('does nothing when customer has no webhook_url', async () => {
|
||||
mockQuery.mockResolvedValue([[{ webhook_url: null, webhook_secret: null }]]);
|
||||
await deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1234' });
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does nothing when customer not found', async () => {
|
||||
mockQuery.mockResolvedValue([[]]);
|
||||
await deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1234' });
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('POSTs to webhook_url with correct headers on success', async () => {
|
||||
mockQuery.mockResolvedValue([[{
|
||||
webhook_url: 'https://example.com/hook',
|
||||
webhook_secret: 'secret123',
|
||||
}]]);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
|
||||
|
||||
await deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1234', text: 'hi' });
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(1);
|
||||
const [url, opts] = (fetch as ReturnType<typeof vi.fn>).mock.calls[0] as [string, RequestInit];
|
||||
expect(url).toBe('https://example.com/hook');
|
||||
expect((opts.headers as Record<string, string>)['Content-Type']).toBe('application/json');
|
||||
expect((opts.headers as Record<string, string>)['X-SquareMCP-Signature']).toMatch(/^sha256=[0-9a-f]{64}$/);
|
||||
});
|
||||
|
||||
it('sends correct HMAC signature', async () => {
|
||||
const secret = 'mysecret';
|
||||
mockQuery.mockResolvedValue([[{ webhook_url: 'https://example.com/hook', webhook_secret: secret }]]);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
|
||||
|
||||
await deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1' });
|
||||
|
||||
const [, opts] = (fetch as ReturnType<typeof vi.fn>).mock.calls[0] as [string, RequestInit];
|
||||
const sig = (opts.headers as Record<string, string>)['X-SquareMCP-Signature'];
|
||||
const body = opts.body as string;
|
||||
|
||||
// Verify the signature independently
|
||||
const { createHmac } = await import('crypto');
|
||||
const expected = `sha256=${createHmac('sha256', secret).update(body).digest('hex')}`;
|
||||
expect(sig).toBe(expected);
|
||||
});
|
||||
|
||||
it('retries on failure and pushes to DLQ after all attempts', async () => {
|
||||
mockQuery.mockResolvedValue([[{ webhook_url: 'https://example.com/hook', webhook_secret: 'sec' }]]);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: false, status: 500 });
|
||||
|
||||
const deliverPromise = deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1' });
|
||||
|
||||
// Advance timers through all retry delays: 1s, 4s, 16s
|
||||
await vi.advanceTimersByTimeAsync(1000);
|
||||
await vi.advanceTimersByTimeAsync(4000);
|
||||
await vi.advanceTimersByTimeAsync(16000);
|
||||
await deliverPromise;
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(4); // 1 initial + 3 retries
|
||||
expect(mockRPush).toHaveBeenCalledWith(
|
||||
'webhook:dlq:cust-1',
|
||||
expect.stringContaining('"customerId":"cust-1"')
|
||||
);
|
||||
expect(mockExpire).toHaveBeenCalledWith('webhook:dlq:cust-1', 604800);
|
||||
});
|
||||
|
||||
it('does not push to DLQ on first-attempt success', async () => {
|
||||
mockQuery.mockResolvedValue([[{ webhook_url: 'https://example.com/hook', webhook_secret: 'sec' }]]);
|
||||
(fetch as ReturnType<typeof vi.fn>).mockResolvedValue({ ok: true });
|
||||
|
||||
await deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1' });
|
||||
|
||||
expect(mockRPush).not.toHaveBeenCalled();
|
||||
expect(mockExpire).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('succeeds on second attempt (one initial failure)', async () => {
|
||||
mockQuery.mockResolvedValue([[{ webhook_url: 'https://example.com/hook', webhook_secret: 'sec' }]]);
|
||||
(fetch as ReturnType<typeof vi.fn>)
|
||||
.mockResolvedValueOnce({ ok: false, status: 503 })
|
||||
.mockResolvedValueOnce({ ok: true });
|
||||
|
||||
const deliverPromise = deliverWebhook('cust-1', 'whatsapp', 'inbound_message', { from: '+1' });
|
||||
await vi.advanceTimersByTimeAsync(1000);
|
||||
await deliverPromise;
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(2);
|
||||
expect(mockRPush).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
84
src/webhooks/delivery.ts
Normal file
84
src/webhooks/delivery.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
import crypto from 'crypto';
|
||||
import redis from '../redis.js';
|
||||
import { getPool } from '../db.js';
|
||||
|
||||
const RETRY_DELAYS_MS = [1000, 4000, 16000];
|
||||
const DLQ_TTL_SECONDS = 604800; // 7 days
|
||||
|
||||
export interface WebhookPayload {
|
||||
customerId: string;
|
||||
platform: string;
|
||||
event: string;
|
||||
data: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export function isValidWebhookUrl(url: string): boolean {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
if (!['http:', 'https:'].includes(parsed.protocol)) return false;
|
||||
const host = parsed.hostname;
|
||||
if (host === 'localhost' || host === '0.0.0.0' || host.endsWith('.local')) return false;
|
||||
// Block RFC-1918 private ranges and loopback
|
||||
if (/^(127\.|10\.|192\.168\.|172\.(1[6-9]|2\d|3[01])\.)/.test(host)) return false;
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function signPayload(secret: string, payload: string): string {
|
||||
return `sha256=${crypto.createHmac('sha256', secret).update(payload).digest('hex')}`;
|
||||
}
|
||||
|
||||
async function postWithRetry(url: string, payload: string, signature: string): Promise<boolean> {
|
||||
for (let attempt = 0; attempt <= RETRY_DELAYS_MS.length; attempt++) {
|
||||
if (attempt > 0) {
|
||||
await new Promise(r => setTimeout(r, RETRY_DELAYS_MS[attempt - 1]));
|
||||
}
|
||||
try {
|
||||
const res = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json', 'X-SquareMCP-Signature': signature },
|
||||
body: payload,
|
||||
signal: AbortSignal.timeout(10_000),
|
||||
});
|
||||
if (res.ok) return true;
|
||||
console.warn(`[webhook] attempt ${attempt + 1} HTTP ${res.status} → ${url}`);
|
||||
} catch (err) {
|
||||
console.warn(`[webhook] attempt ${attempt + 1} error:`, (err as Error).message);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export async function deliverWebhook(
|
||||
customerId: string,
|
||||
platform: string,
|
||||
event: string,
|
||||
data: Record<string, unknown>
|
||||
): Promise<void> {
|
||||
const [rows] = await getPool().query<any[]>(
|
||||
'SELECT webhook_url, webhook_secret FROM customers WHERE id = ?',
|
||||
[customerId]
|
||||
);
|
||||
if (!rows.length || !rows[0].webhook_url || !rows[0].webhook_secret) return;
|
||||
|
||||
const { webhook_url, webhook_secret } = rows[0] as { webhook_url: string; webhook_secret: string };
|
||||
const payload: WebhookPayload = { customerId, platform, event, data };
|
||||
const payloadStr = JSON.stringify(payload);
|
||||
const signature = signPayload(webhook_secret, payloadStr);
|
||||
|
||||
const delivered = await postWithRetry(webhook_url, payloadStr, signature);
|
||||
|
||||
if (!delivered) {
|
||||
console.error(`[webhook] all attempts failed for customer ${customerId}, pushing to DLQ`);
|
||||
const dlqKey = `webhook:dlq:${customerId}`;
|
||||
const entry = JSON.stringify({
|
||||
payload,
|
||||
failedAt: new Date().toISOString(),
|
||||
attempts: RETRY_DELAYS_MS.length + 1,
|
||||
});
|
||||
await redis.rPush(dlqKey, entry);
|
||||
await redis.expire(dlqKey, DLQ_TTL_SECONDS);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user