import { useState, useCallback, useRef } from 'react';
import { env, AutoModel, AutoProcessor, RawImage } from '@xenova/transformers';

// Let transformers.js resolve model files locally and cache fetched files in
// the browser. NOTE(review): this does NOT restrict loading to local files —
// remote downloads are only disabled via env.allowRemoteModels = false, which
// is not set here. Confirm the intended policy.
env.allowLocalModels = true;
env.useBrowserCache = true;

const MODEL_ID = 'Xenova/rmbg-1.4';
// NOTE(review): `local_model_path` is passed to from_pretrained below but does
// not appear among the documented transformers.js loading options; local
// lookup is normally governed by env.localModelPath. Kept unchanged — confirm
// whether it has any effect.
const LOCAL_MODEL_PATH = '/models/rmbg-1.4';

/**
 * Creates an image element for `src` and resolves when its load event fires.
 * @param {string} src - Any URL the browser can load (http(s), data:, blob:).
 * @returns {Promise<HTMLImageElement>} Rejects if the image fails to load.
 */
function loadImageElement(src) {
  return new Promise((resolve, reject) => {
    const img = new Image();
    img.crossOrigin = 'anonymous';
    // Attach handlers before assigning src so neither event can be missed.
    img.onload = () => resolve(img);
    img.onerror = () => reject(new Error('Failed to load image for background removal'));
    img.src = src;
  });
}

/**
 * Copies mask values into the alpha channel of an RGBA pixel buffer, in place.
 * @param {Uint8ClampedArray} rgba - Interleaved RGBA pixels (4 bytes per pixel).
 * @param {{ data: ArrayLike<number>, width: number, height: number, channels: number }} mask
 *   Mask image with the same width/height as `rgba`; only its first channel is read,
 *   so a mask that comes back with more than one channel is still handled.
 */
function applyAlphaMask(rgba, mask) {
  const stride = mask.channels;
  const pixelCount = Math.min(mask.width * mask.height, Math.floor(rgba.length / 4));
  for (let i = 0; i < pixelCount; i += 1) {
    rgba[i * 4 + 3] = mask.data[i * stride];
  }
}

/**
 * Encodes a canvas as a PNG blob.
 * @param {HTMLCanvasElement} canvas
 * @returns {Promise<Blob>} Rejects if `toBlob` yields null (e.g. a zero-sized canvas).
 */
function canvasToPngBlob(canvas) {
  return new Promise((resolve, reject) => {
    canvas.toBlob((blob) => {
      if (blob) {
        resolve(blob);
      } else {
        reject(new Error('Canvas could not be encoded as PNG'));
      }
    }, 'image/png');
  });
}

/**
 * React hook that lazily loads the RMBG-1.4 background-removal model and
 * exposes a function that turns an image into a PNG with its background made
 * transparent.
 *
 * @returns {{
 *   loading: boolean,
 *   progress: number,
 *   hasModel: boolean,
 *   loadModel: () => Promise<boolean>,
 *   removeBackground: (imageSrc: string) => Promise<string|null>,
 * }}
 *   - loading: true while the model is loading or an image is being processed.
 *   - progress: last model-download progress reported by transformers.js, rounded.
 *   - hasModel: true once both model and processor have loaded.
 *   - loadModel: resolves true on success (or if already loaded), false on failure.
 *   - removeBackground: resolves to an object URL for the PNG result, or null on
 *     failure. Callers own the URL and should URL.revokeObjectURL it when done.
 */
export function useBackgroundRemoval() {
  const [loading, setLoading] = useState(false);
  const [progress, setProgress] = useState(0);
  const [hasModel, setHasModel] = useState(false);
  // Loaded instances live in a ref rather than state: a state update is not
  // visible to a callback that is already running, so reading state right
  // after a first-time load would still see null.
  const pipelineRef = useRef(null);
  // In-flight load shared by concurrent callers so the model downloads once.
  const pendingLoadRef = useRef(null);

  /** Resolves to { model, processor }, loading them if needed; null on failure. */
  const ensurePipeline = useCallback(() => {
    if (pipelineRef.current) return Promise.resolve(pipelineRef.current);
    if (pendingLoadRef.current) return pendingLoadRef.current;

    pendingLoadRef.current = (async () => {
      setLoading(true);
      setProgress(0);
      try {
        const [model, processor] = await Promise.all([
          AutoModel.from_pretrained(MODEL_ID, {
            progress_callback: (data) => {
              if (data.status === 'progress') {
                setProgress(Math.round(data.progress));
              }
            },
            local_model_path: LOCAL_MODEL_PATH,
          }),
          AutoProcessor.from_pretrained(MODEL_ID, {
            local_model_path: LOCAL_MODEL_PATH,
          }),
        ]);
        pipelineRef.current = { model, processor };
        setHasModel(true);
        return pipelineRef.current;
      } catch (error) {
        console.error('Failed to load background removal model:', error);
        return null;
      } finally {
        // Clear so a failed load can be retried on the next call.
        pendingLoadRef.current = null;
        setLoading(false);
      }
    })();
    return pendingLoadRef.current;
  }, []);

  const loadModel = useCallback(
    async () => (await ensurePipeline()) !== null,
    [ensurePipeline],
  );

  const removeBackground = useCallback(async (imageSrc) => {
    // Use the instances returned here, not captured state, so the very first
    // call (which triggers the load) works.
    const pipeline = await ensurePipeline();
    if (!pipeline) return null;
    const { model, processor } = pipeline;

    setLoading(true);
    try {
      const img = await loadImageElement(imageSrc);
      const width = img.naturalWidth;
      const height = img.naturalHeight;

      // Draw the original image once; its pixels feed both the model and the output.
      const canvas = document.createElement('canvas');
      canvas.width = width;
      canvas.height = height;
      const ctx = canvas.getContext('2d');
      ctx.drawImage(img, 0, 0);
      const imageData = ctx.getImageData(0, 0, width, height);

      // transformers.js processors operate on RawImage, not HTMLImageElement.
      const rawImage = new RawImage(imageData.data, width, height, 4);
      const { pixel_values } = await processor(rawImage);

      // The RMBG-1.4 model card's transformers.js example feeds the tensor as
      // `input`. `pixel_values` is also passed for exports using that name; the
      // unused key is expected to be ignored (transformers.js feeds only the
      // session's declared inputs) — confirm against the local model.
      const { output } = await model({ input: pixel_values, pixel_values });

      // Scale the [0,1] mask to 0-255 bytes and resize it back to the source size.
      const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
        width,
        height,
      );

      applyAlphaMask(imageData.data, mask);
      ctx.putImageData(imageData, 0, 0);

      const blob = await canvasToPngBlob(canvas);
      return URL.createObjectURL(blob);
    } catch (error) {
      console.error('Background removal failed:', error);
      return null;
    } finally {
      setLoading(false);
    }
  }, [ensurePipeline]);

  return {
    loading,
    progress,
    hasModel,
    loadModel,
    removeBackground,
  };
}