Add routing and summary models, and start JSON patches on a smaller model that falls back to a bigger model immediately on failure.

khalid@traclabs.com
2026-04-23 09:10:52 -05:00
parent c17ce052c1
commit 3cb0cbe088
2 changed files with 35 additions and 10 deletions
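The gist of the change: routing, summaries, and the first pass at JSON patches move to smaller models, and the JSON path gets a single attempt on the small model before the larger fallback takes over. A minimal sketch of that per-model attempt budget, with simplified signatures (the real logic lives in generateWithValidation in the diff below):

// Sketch only: each model in the cascade gets its own attempt budget, so the
// small JSON model fails fast and the bigger fallback absorbs the retries.
async function cascade<T>(
  models: string[],                       // e.g. [JSON_MODEL, JSON_FALLBACK_MODEL]
  budgets: number[],                      // e.g. [1, MAX_RETRIES]
  tryOnce: (model: string) => Promise<T>, // one chat call plus parse/validate
): Promise<T> {
  let lastErr: unknown;
  for (const [i, model] of models.entries()) {
    const budget = budgets[i] ?? 3;       // MAX_RETRIES is 3 in the patch
    for (let attempt = 0; attempt < budget; attempt++) {
      try {
        return await tryOnce(model);
      } catch (err) {
        lastErr = err;                    // remember the last failure and move on
      }
    }
  }
  throw lastErr;
}

With models = [JSON_MODEL, JSON_FALLBACK_MODEL] and budgets = [1, MAX_RETRIES], the small model gets exactly one shot, which is what maxRetriesByModel: [1, MAX_RETRIES] expresses in the patch.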


@@ -5,8 +5,11 @@ API_EDIT_SECRET=change-me-to-a-random-string
OLLAMA_API_KEY=
# For Ollama Cloud use https://ollama.com, for local Ollama use http://localhost:11434
OLLAMA_HOST=https://ollama.com
OLLAMA_MODEL=qwen3.5:397b-cloud
OLLAMA_INTENT_MODEL=gemma4:31b-cloud
OLLAMA_ROUTING_MODEL=gemma4:31b-cloud
OLLAMA_SUMMARY_MODEL=gemma4:31b-cloud
OLLAMA_JSON_MODEL=gemma4:31b-cloud
OLLAMA_JSON_FALLBACK_MODEL=qwen3.5:397b-cloud
OLLAMA_FALLBACK_MODEL=gpt-oss:120b
# Paths

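For a local Ollama install (the http://localhost:11434 host mentioned in the comment above), the same variables would point at locally pulled models. The placeholders below are illustrative, not tested defaults:

OLLAMA_HOST=http://localhost:11434
OLLAMA_MODEL=<large-local-model>            # illustrative placeholder, use whatever you have pulled
OLLAMA_INTENT_MODEL=<small-local-model>
OLLAMA_ROUTING_MODEL=<small-local-model>
OLLAMA_SUMMARY_MODEL=<small-local-model>
OLLAMA_JSON_MODEL=<small-local-model>
OLLAMA_JSON_FALLBACK_MODEL=<large-local-model>
OLLAMA_FALLBACK_MODEL=<large-local-model>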

@@ -4,9 +4,12 @@ import { logger } from '../logger.js';
const OLLAMA_HOST = process.env.OLLAMA_HOST || 'http://localhost:11434';
const OLLAMA_API_KEY = process.env.OLLAMA_API_KEY || '';
const PRIMARY_MODEL = process.env.OLLAMA_MODEL || 'qwen3.5:397b-cloud';
const FALLBACK_MODEL = process.env.OLLAMA_FALLBACK_MODEL || 'gpt-oss:120b';
const INTENT_MODEL = process.env.OLLAMA_INTENT_MODEL || 'gemma4:31b-cloud';
const ROUTING_MODEL = process.env.OLLAMA_ROUTING_MODEL || 'gemma4:31b-cloud';
const SUMMARY_MODEL = process.env.OLLAMA_SUMMARY_MODEL || 'gemma4:31b-cloud';
const JSON_MODEL = process.env.OLLAMA_JSON_MODEL || 'qwen3.5:397b-cloud';
const JSON_FALLBACK_MODEL = process.env.OLLAMA_JSON_FALLBACK_MODEL || FALLBACK_MODEL;
const MAX_RETRIES = 3;
export interface LlmChatCaller {
@@ -41,13 +44,20 @@ async function generateWithValidation<T>(params: {
schema: z.ZodType<T>;
chat?: LlmChatCaller;
models?: string[];
maxRetries?: number;
maxRetriesByModel?: number[];
}): Promise<T> {
const chat = params.chat || ollamaChat;
const models = params.models?.length ? params.models : [PRIMARY_MODEL, FALLBACK_MODEL];
const models = params.models?.length ? params.models : [JSON_MODEL, JSON_FALLBACK_MODEL];
for (const model of models) {
for (const [modelIndex, model] of models.entries()) {
const msgs = [...params.messages];
for (let attempt = 0; attempt < MAX_RETRIES; attempt++) {
const maxRetries =
params.maxRetriesByModel?.[modelIndex] ??
params.maxRetries ??
MAX_RETRIES;
for (let attempt = 0; attempt < maxRetries; attempt++) {
logger.debug({ event: 'llm.request', model, attempt }, 'LLM call');
try {
const raw = await chat(msgs, model);
@@ -69,7 +79,7 @@ async function generateWithValidation<T>(params: {
);
} catch (err) {
logger.warn({ event: 'llm.retry', model, attempt, error: (err as Error).message }, 'LLM call or parse failed');
if (attempt === MAX_RETRIES - 1) break;
if (attempt === maxRetries - 1) break;
msgs.push(
{ role: 'user', content: `Your response was not valid JSON. Please respond with ONLY a JSON object, no markdown or extra text.` }
);
@@ -112,7 +122,14 @@ The output must be valid JSON matching the exact same schema as the input.`,
},
];
return generateWithValidation({ messages, schema: params.schema, chat });
return generateWithValidation({
messages,
schema: params.schema,
chat,
models: [JSON_MODEL, JSON_FALLBACK_MODEL],
// Switch to fallback immediately after one failure on the primary JSON model.
maxRetriesByModel: [1, MAX_RETRIES],
});
}
export interface RouteEditIntentParams {
@@ -147,7 +164,12 @@ Example response:
},
];
const routed = await generateWithValidation({ messages, schema: routingOutputSchema, chat });
const routed = await generateWithValidation({
messages,
schema: routingOutputSchema,
chat,
models: [ROUTING_MODEL, FALLBACK_MODEL],
});
// Some TS setups infer optional fields from the Zod schema; normalize to our contract type.
return {
...routed,
@@ -233,7 +255,7 @@ Rules:
];
try {
const result = await chatFn(messages, PRIMARY_MODEL);
const result = await chatFn(messages, SUMMARY_MODEL);
return result.trim().slice(0, 320);
} catch {
// Fallback: generate a basic listing from the manifest
@@ -319,7 +341,7 @@ Bad examples (too technical — never do this):
];
try {
const result = await chat(messages, PRIMARY_MODEL);
const result = await chat(messages, SUMMARY_MODEL);
return result.replace(/["'`]/g, '').trim().slice(0, 280);
} catch {
// Fallback: echo back the user's original request
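Taken together, the two summary paths above try the small SUMMARY_MODEL once and fall back to a non-LLM answer rather than retrying. A self-contained sketch of that shape, with simplified types and an assumed chat signature (this is not the project's actual helper):

// One attempt on the summary model, then a plain-text fallback so the
// feature still returns something useful if the model call errors.
type ChatMessage = { role: string; content: string };
type ChatFn = (messages: ChatMessage[], model: string) => Promise<string>;

async function summarizeWithFallback(
  chat: ChatFn,
  messages: ChatMessage[],
  fallbackText: string,
  summaryModel: string = process.env.OLLAMA_SUMMARY_MODEL ?? 'gemma4:31b-cloud',
): Promise<string> {
  try {
    const result = await chat(messages, summaryModel);
    return result.trim().slice(0, 320); // same length cap as the first summary call above
  } catch {
    return fallbackText.slice(0, 320);  // non-LLM fallback, as in the catch blocks above
  }
}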