diff --git a/lib/ai-client/chat.ts b/lib/ai-client/chat.ts index 3241a27..f869c8f 100644 --- a/lib/ai-client/chat.ts +++ b/lib/ai-client/chat.ts @@ -1,10 +1,7 @@ import { generateText } from "ai"; import type { LanguageModelUsage, ModelMessage } from "ai"; -import { createAnthropic } from "@ai-sdk/anthropic"; -import { createGoogleGenerativeAI } from "@ai-sdk/google"; -import { createOpenAI } from "@ai-sdk/openai"; -import type { ProviderConfig, ProviderProtocol } from "@infiplot/types"; -import { normalizeBaseUrl } from "./normalizeUrl"; +import type { ProviderConfig } from "@infiplot/types"; +import { createLanguageModel, resolveProtocol } from "./model"; export type ChatMessage = { role: "system" | "user" | "assistant"; @@ -31,24 +28,6 @@ function summarizeSdkUsage( return `[cache] ${tag} input=${input} completion=${output} (provider didn't report cache stats)`; } -function resolveTextProtocol(config: ProviderConfig): ProviderProtocol { - return config.provider ?? "openai_compatible"; -} - -function createLanguageModel(config: ProviderConfig, protocol: ProviderProtocol) { - const baseURL = normalizeBaseUrl(config.baseUrl, protocol); - switch (protocol) { - case "anthropic": - return createAnthropic({ apiKey: config.apiKey, baseURL })(config.model); - case "google": - return createGoogleGenerativeAI({ apiKey: config.apiKey, baseURL })(config.model); - case "openai_compatible": - case "openai": - default: - return createOpenAI({ apiKey: config.apiKey, baseURL }).chat(config.model); - } -} - export async function chat( config: ProviderConfig, messages: ChatMessage[], @@ -57,7 +36,7 @@ export async function chat( tag?: string; }, ): Promise { - const protocol = resolveTextProtocol(config); + const protocol = resolveProtocol(config); const model = createLanguageModel(config, protocol); const system = messages.find((m) => m.role === "system")?.content; diff --git a/lib/ai-client/model.ts b/lib/ai-client/model.ts new file mode 100644 index 0000000..155e424 --- /dev/null +++ b/lib/ai-client/model.ts @@ -0,0 +1,23 @@ +import { createAnthropic } from "@ai-sdk/anthropic"; +import { createGoogleGenerativeAI } from "@ai-sdk/google"; +import { createOpenAI } from "@ai-sdk/openai"; +import type { ProviderConfig, ProviderProtocol } from "@infiplot/types"; +import { normalizeBaseUrl } from "./normalizeUrl"; + +export function resolveProtocol(config: ProviderConfig): ProviderProtocol { + return config.provider ?? "openai_compatible"; +} + +export function createLanguageModel(config: ProviderConfig, protocol: ProviderProtocol) { + const baseURL = normalizeBaseUrl(config.baseUrl, protocol); + switch (protocol) { + case "anthropic": + return createAnthropic({ apiKey: config.apiKey, baseURL })(config.model); + case "google": + return createGoogleGenerativeAI({ apiKey: config.apiKey, baseURL })(config.model); + case "openai_compatible": + case "openai": + default: + return createOpenAI({ apiKey: config.apiKey, baseURL }).chat(config.model); + } +} diff --git a/lib/ai-client/vision.ts b/lib/ai-client/vision.ts index 8583180..12df0fa 100644 --- a/lib/ai-client/vision.ts +++ b/lib/ai-client/vision.ts @@ -1,10 +1,7 @@ import { generateText } from "ai"; import type { ModelMessage } from "ai"; -import { createAnthropic } from "@ai-sdk/anthropic"; -import { createGoogleGenerativeAI } from "@ai-sdk/google"; -import { createOpenAI } from "@ai-sdk/openai"; -import type { ProviderConfig, ProviderProtocol } from "@infiplot/types"; -import { normalizeBaseUrl } from "./normalizeUrl"; +import type { ProviderConfig } from "@infiplot/types"; +import { createLanguageModel, resolveProtocol } from "./model"; const VISION_TIMEOUT_MS = 60_000; @@ -20,32 +17,13 @@ export async function interpretClick( ); } -function resolveVisionProtocol(config: ProviderConfig): ProviderProtocol { - return config.provider ?? "openai_compatible"; -} - export async function analyzeImageDataUrl( config: ProviderConfig, imageDataUrl: string, prompt: string, ): Promise { - const protocol = resolveVisionProtocol(config); - const baseURL = normalizeBaseUrl(config.baseUrl, protocol); - - let model; - switch (protocol) { - case "anthropic": - model = createAnthropic({ apiKey: config.apiKey, baseURL })(config.model); - break; - case "google": - model = createGoogleGenerativeAI({ apiKey: config.apiKey, baseURL })(config.model); - break; - case "openai_compatible": - case "openai": - default: - model = createOpenAI({ apiKey: config.apiKey, baseURL }).chat(config.model); - break; - } + const protocol = resolveProtocol(config); + const model = createLanguageModel(config, protocol); const messages: ModelMessage[] = [ {