feat: added web worker support #129

Merged
117 changes: 113 additions & 4 deletions examples/chat-demo/src/components/ChatInterface.tsx
@@ -129,6 +129,11 @@ const Button = styled.button`
}
`;

const TestWorkerButton = styled(Button)`
font-size: 12px;
padding: 4px 8px;
`;

const Spinner = styled.div`
width: 40px;
height: 40px;
@@ -339,6 +344,28 @@ const ThinkingDropdown = styled.div<{ isOpen: boolean }>`
}
`;

const WorkerStatus = styled.div<{ active: boolean }>`
position: fixed;
bottom: 20px;
right: 20px;
padding: 8px 12px;
border-radius: 4px;
background: ${props => props.active ? '#1a392c' : '#2d2d2d'};
color: ${props => props.active ? '#4ade80' : '#a0a0a0'};
font-size: 12px;
display: flex;
align-items: center;
gap: 6px;

&::before {
content: '';
width: 8px;
height: 8px;
border-radius: 50%;
background: ${props => props.active ? '#4ade80' : '#a0a0a0'};
}
`;

interface ChatInterfaceProps {
children?: (props: {
stats: {
@@ -378,8 +405,9 @@ interface LoadingStats {
}

export default function ChatInterface({ children }: ChatInterfaceProps) {
const [browserAI] = useState(new BrowserAI());
const [browserAI] = useState(() => new BrowserAI());
const [selectedModel, setSelectedModel] = useState('smollm2-135m-instruct');
const [useWebWorker, setUseWebWorker] = useState(false);
const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Array<{ text: string; isUser: boolean }>>([]);
const [input, setInput] = useState('');
@@ -423,7 +451,11 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {


const loadModel = async () => {
console.log(`[BrowserAI] Starting to load model: ${selectedModel}`);
console.log(`[BrowserAI] Starting to load model with worker: ${useWebWorker}`);

// Add performance markers
performance.mark('modelLoadStart');

setLoading(true);
setLoadError(null);
const startTime = performance.now();
@@ -437,8 +469,9 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {

try {
await browserAI.loadModel(selectedModel, {
useWorker: useWebWorker,
onProgress: (progress: any) => {
console.log(`[BrowserAI] Loading progress:`, progress);
console.log(`[${useWebWorker ? 'Worker' : 'Main'}] Progress:`, progress);
const currentTime = performance.now();
const elapsedTime = (currentTime - startTime) / 1000; // in seconds
const progressPercent = progress.progress;
@@ -483,7 +516,7 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
});

const loadTime = performance.now() - startTime;
console.log(`[BrowserAI] Model loaded successfully in ${loadTime.toFixed(0)}ms`);
console.log(`[BrowserAI] Model loaded successfully in ${loadTime.toFixed(0)}ms using ${useWebWorker ? 'Web Worker' : 'Main Thread'}`);
const memoryAfter = (performance as any).memory?.usedJSHeapSize;
const memoryIncrease = memoryAfter - memoryBefore;

@@ -509,6 +542,9 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
progress: 0,
estimatedTimeRemaining: null
});

performance.mark('modelLoadEnd');
performance.measure('Model Load Time', 'modelLoadStart', 'modelLoadEnd');
};

const handleModelChange = (newModel: string) => {
@@ -519,6 +555,12 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
const handleSend = async () => {
if (!input.trim() || !modelLoaded) return;

console.log(`[BrowserAI] Starting generation using ${
useWebWorker ? 'Web Worker' : 'Main Thread'
}`);

performance.mark('generationStart');

console.log(`[BrowserAI] Starting text generation with input length: ${input.length}`);
const userMessage = { text: input, isUser: true };
setMessages(prev => [...prev, userMessage]);
@@ -575,6 +617,13 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
};
});

performance.mark('generationEnd');
performance.measure('Generation Time', 'generationStart', 'generationEnd');

// Use the most recent measure: getEntriesByName returns every
// 'Generation Time' entry recorded so far, oldest first.
const genEntries = performance.getEntriesByName('Generation Time');
console.log(`[BrowserAI] Generation completed in ${
genEntries[genEntries.length - 1].duration.toFixed(0)
}ms using ${useWebWorker ? 'Web Worker' : 'Main Thread'}`);

} catch (err) {
const error = err as Error;
console.error('[BrowserAI] Error generating text:', {
@@ -690,6 +739,55 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
<option value="gemma-2b-it">Gemma 2B Instruct (1.4GB)</option>
<option value="tinyllama-1.1b-chat-v0.4">TinyLlama 1.1B Chat (670MB)</option>
</ModelSelect>

<div style={{
display: 'flex',
alignItems: 'center',
gap: '8px'
}}>
<input
type="checkbox"
id="worker-toggle"
checked={useWebWorker}
onChange={e => setUseWebWorker(e.target.checked)}
disabled={loading || modelLoaded}
/>
<label
htmlFor="worker-toggle"
style={{ color: '#a0a0a0', fontSize: '14px' }}
>
Use Web Worker
</label>
</div>

<TestWorkerButton
onClick={async () => {
if (!modelLoaded) return;

// Start a UI-blocking operation
const startTime = performance.now();

// Generate text while also updating UI
let dots = '';
const updateInterval = setInterval(() => {
dots = dots.length >= 3 ? '' : dots + '.';
setInput(`Testing worker${dots}`);
}, 100);

try {
await browserAI.generateText('Generate a long story about a cat.');
clearInterval(updateInterval);
setInput(`Test completed in ${(performance.now() - startTime).toFixed(0)}ms`);
} catch (err) {
clearInterval(updateInterval);
setInput('Test failed');
}
}}
disabled={!modelLoaded}
>
Test Worker
</TestWorkerButton>

<Button
onClick={loadModel}
disabled={loading || modelLoaded}
@@ -880,8 +978,19 @@ export default function ChatInterface({ children }: ChatInterfaceProps) {
</div>
</StatItem>

<StatItem>
<h3>Processing Mode</h3>
<div style={{ color: '#fff', fontSize: '14px' }}>
{useWebWorker ? 'Web Worker (Background Thread)' : 'Main Thread'}
</div>
</StatItem>

</Sidebar>
</MainContent>

<WorkerStatus active={useWebWorker && modelLoaded}>
{useWebWorker ? 'Web Worker Active' : 'Main Thread'}
</WorkerStatus>
</Layout>
);
}
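
Reviewer note: the demo now brackets model loading and generation with `performance.mark()`/`performance.measure()` calls. A minimal sketch for inspecting those measures from the browser console, using only the measure names visible in the diff above ('Model Load Time', 'Generation Time'); everything else is the standard User Timing API:

```ts
// List every timing measure the demo has recorded so far.
const entries = [
  ...performance.getEntriesByName('Model Load Time'),
  ...performance.getEntriesByName('Generation Time'),
];

for (const entry of entries) {
  // Each entry is a PerformanceMeasure with a duration in milliseconds.
  console.log(`${entry.name}: ${entry.duration.toFixed(0)}ms`);
}
```

Comparing the 'Generation Time' entries with the worker toggle on and off is a quick way to check the worker path without opening profiling tools.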
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@browserai/browserai",
"version": "1.0.27",
"version": "1.0.28",
"private": false,
"description": "A library for running AI models directly in the browser",
"main": "dist/index.js",
8 changes: 8 additions & 0 deletions src/core/llm/index.ts
@@ -198,4 +198,12 @@ export class BrowserAI {

throw new Error('Current engine does not support multimodal generation');
}

dispose() {
if (this.engine instanceof MLCEngineWrapper) {
this.engine.dispose();
}
this.engine = null;
this.currentModel = null;
}
}
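
For context, here is a minimal consumer-side sketch of the new option, assuming the public `BrowserAI` API exercised by the demo above (the package import, model id, and prompt are illustrative):

```ts
import { BrowserAI } from '@browserai/browserai';

const ai = new BrowserAI();

// useWorker: true routes loading and inference through the new worker path.
await ai.loadModel('smollm2-135m-instruct', {
  useWorker: true,
  onProgress: (p: any) => console.log('load progress:', p.progress),
});

const reply = await ai.generateText('Hello from a background thread!');
console.log(reply);

// New in this PR: releases the engine and terminates the worker.
ai.dispose();
```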
105 changes: 96 additions & 9 deletions src/engines/mlc-engine-wrapper.ts
@@ -1,18 +1,74 @@
// src/engines/mlc-engine-wrapper.ts
import { CreateMLCEngine, MLCEngineInterface, AppConfig, modelLibURLPrefix, modelVersion, prebuiltAppConfig } from '@mlc-ai/web-llm';
import {
CreateMLCEngine,
MLCEngineInterface,
AppConfig,
modelLibURLPrefix,
modelVersion,
prebuiltAppConfig,
CreateWebWorkerMLCEngine
} from '@mlc-ai/web-llm';
import { ModelConfig } from '../config/models/types';

interface MLCLoadModelOptions {
useWorker?: boolean;
onProgress?: (progress: any) => void;
quantization?: string;
[key: string]: any;
}

export class MLCEngineWrapper {
private mlcEngine: MLCEngineInterface | null = null;
private appConfig: AppConfig | null = null;
private worker: Worker | null = null;

constructor() {
this.mlcEngine = null;
}

async loadModel(modelConfig: ModelConfig, options: any = {}) {
async loadModel(modelConfig: ModelConfig, options: MLCLoadModelOptions = {}) {
try {
// Clean up any existing worker
if (this.worker) {
this.worker.terminate();
this.worker = null;
}

// Create new worker if requested
if (options.useWorker) {
console.log('[MLCEngine] Creating new worker');

this.worker = new Worker(
new URL('../workers/mlc.worker.ts', import.meta.url),
{ type: 'module' }
);

// Add error handling for worker
this.worker.onerror = (error) => {
console.error('[MLCEngine] Worker error:', error);
throw new Error(`Worker error: ${error.message}`);
};

this.worker.onmessageerror = (error) => {
console.error('[MLCEngine] Worker message error:', error);
};

// Listen for messages from worker
this.worker.onmessage = (msg) => {
console.log('[MLCEngine] Received worker message:', msg.data);
if (msg.data.type === 'error') {
throw new Error(`Worker error: ${msg.data.error}`);
}
};

console.log('[MLCEngine] Worker created successfully');
}

const quantization = options.quantization || modelConfig.defaultQuantization;
const modelIdentifier = modelConfig.repo.replace('{quantization}', quantization).split('/')[1];

console.log('[MLCEngine] Loading model:', modelIdentifier, 'with worker:', !!this.worker);

if (modelConfig.modelLibrary) {
this.appConfig = {
model_list: [
@@ -29,14 +85,37 @@ export class MLCEngineWrapper {
else {
this.appConfig = prebuiltAppConfig;
}
// console.log(this.appConfig);
this.mlcEngine = await CreateMLCEngine(modelIdentifier, {
initProgressCallback: options.onProgress, // Pass progress callback
appConfig: this.appConfig,
...options, // Pass other options
});

if (this.worker) {
console.log('[MLCEngine] Creating web worker engine');
this.mlcEngine = await CreateWebWorkerMLCEngine(
this.worker,
modelIdentifier,
{
initProgressCallback: (progress: any) => {
console.log('[MLCEngine] Loading progress:', progress);
options.onProgress?.(progress);
},
appConfig: this.appConfig,
...options,
}
);
console.log('[MLCEngine] Web worker engine created successfully');
} else {
this.mlcEngine = await CreateMLCEngine(modelIdentifier, {
initProgressCallback: options.onProgress,
appConfig: this.appConfig,
...options,
});
}
} catch (error) {
console.error('Error loading MLC model:', error);
// Clean up worker if initialization failed
if (this.worker) {
console.error('[MLCEngine] Error with worker, cleaning up');
this.worker.terminate();
this.worker = null;
}
console.error('[MLCEngine] Error loading model:', error);
const message = error instanceof Error ? error.message : String(error);
throw new Error(`Failed to load MLC model "${modelConfig.repo}": ${message}`);
}
@@ -103,4 +182,12 @@ export class MLCEngineWrapper {
const result = await this.mlcEngine.embeddings.create({ input, ...options });
return result.data[0].embedding;
}

dispose() {
if (this.worker) {
this.worker.terminate();
this.worker = null;
}
this.mlcEngine = null;
}
}
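
One caveat worth flagging: throwing inside `worker.onerror` or `worker.onmessage` (as the handlers above do) raises an uncaught exception in the event loop; it cannot reject the promise returned by `loadModel`. A hypothetical alternative that surfaces the same `{ type: 'error' }` messages to the caller:

```ts
// Hypothetical helper: converts worker failures into a rejected promise,
// since a throw inside an event handler never reaches loadModel's caller.
function workerFailure(worker: Worker): Promise<never> {
  return new Promise((_, reject) => {
    worker.addEventListener('message', (msg: MessageEvent) => {
      if (msg.data?.type === 'error') {
        reject(new Error(`Worker error: ${msg.data.error}`));
      }
    });
    worker.addEventListener('error', (event) => {
      reject(new Error(`Worker error: ${event.message}`));
    });
  });
}

// Usage sketch inside loadModel:
// this.mlcEngine = await Promise.race([
//   CreateWebWorkerMLCEngine(this.worker, modelIdentifier, engineOptions),
//   workerFailure(this.worker),
// ]);
```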
31 changes: 31 additions & 0 deletions src/workers/mlc.worker.ts
@@ -0,0 +1,31 @@
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

console.log('[Worker] Initializing MLC Web Worker');

const handler = new WebWorkerMLCEngineHandler();

// Add error handling and logging
self.onerror = (error) => {
console.error('[Worker] Error:', error);
};

self.onmessageerror = (error) => {
console.error('[Worker] Message Error:', error);
};

self.onmessage = (msg) => {
console.log('[Worker] Received message:', msg.data);
try {
handler.onmessage(msg);
} catch (error) {
console.error('[Worker] Handler error:', error);
// Notify main thread of error
self.postMessage({
type: 'error',
error: error instanceof Error ? error.message : String(error)
});
}
};

// Log when handler is ready
console.log('[Worker] MLC Web Worker initialized');