Files
apparel-designer/client/src/hooks/useBackgroundRemoval.js
Khalid A 4a735e2f2e Phase 4: Background Removal (Transformers.js)
- Added @xenova/transformers dependency
- useBackgroundRemoval hook with RMBG-1.4 model
- Client-side background removal with progress indicator
- Background removal button in properties panel (image elements only)
- ~170MB model cached after first download

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-21 01:19:03 -05:00

121 lines
3.2 KiB
JavaScript

import { useState, useCallback, useRef } from 'react';
import { env, AutoModel, AutoProcessor, RawImage } from '@xenova/transformers';
// Transformers.js runtime settings for this module.
// - allowLocalModels: let Transformers.js look for model files served by this
//   app (resolved under `env.localModelPath`).
// - useBrowserCache: keep downloaded model files in the browser cache so the
//   large RMBG weights are fetched only once.
// NOTE: this is NOT "local models only". `env.allowRemoteModels` is not set
// here, so under the library default, files that are missing locally are
// still downloaded from the Hugging Face Hub.
env.allowLocalModels = true;
env.useBrowserCache = true;
const MODEL_ID = 'Xenova/rmbg-1.4';
// NOTE(review): `local_model_path` is not among the documented Transformers.js
// v2 `from_pretrained` options. Local files are normally resolved as
// `env.localModelPath + modelId`. It is kept as-is; confirm it has any effect.
const LOCAL_MODEL_PATH = '/models/rmbg-1.4';

/**
 * Load an image URL into a decoded HTMLImageElement.
 * @param {string} src - A URL the browser can load (http(s), data:, blob:).
 * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded.
 */
function loadImage(src) {
  return new Promise((resolve, reject) => {
    const img = new Image();
    // Request CORS so a cross-origin image does not taint the canvas, which
    // would make getImageData() throw.
    img.crossOrigin = 'anonymous';
    img.onload = () => resolve(img);
    img.onerror = () => reject(new Error('Failed to load image for background removal'));
    img.src = src;
  });
}

/**
 * Encode a canvas as a PNG blob.
 * @param {HTMLCanvasElement} canvas
 * @returns {Promise<Blob>}
 * @throws {Error} If the browser yields no blob (e.g. a zero-sized canvas).
 */
function canvasToPngBlob(canvas) {
  return new Promise((resolve, reject) => {
    canvas.toBlob((blob) => {
      if (blob) {
        resolve(blob);
      } else {
        reject(new Error('Canvas could not be encoded as PNG'));
      }
    }, 'image/png');
  });
}

/**
 * React hook for client-side background removal with the RMBG-1.4 model via
 * Transformers.js.
 *
 * The model and processor load lazily on the first `removeBackground` call,
 * or eagerly via `loadModel`. They are held in refs, so a callback can use a
 * model it just loaded without waiting for a re-render. Concurrent
 * `loadModel` calls share a single in-flight download.
 *
 * @returns {{
 *   loading: boolean,
 *   progress: number,
 *   hasModel: boolean,
 *   loadModel: () => Promise<boolean>,
 *   removeBackground: (imageSrc: string) => Promise<string | null>,
 * }}
 *   - `loading` is true while the model loads or an image is processed.
 *   - `progress` is the last rounded `progress` value reported by the model
 *     download callback.
 *   - `loadModel` resolves to true once the model and processor are ready,
 *     and to false if loading failed. A later call retries.
 *   - `removeBackground` resolves to a blob URL for a PNG with the background
 *     made transparent, or to null on failure. The hook never revokes this
 *     URL.
 */
export function useBackgroundRemoval() {
  const [loading, setLoading] = useState(false);
  const [progress, setProgress] = useState(0);
  const [hasModel, setHasModel] = useState(false);
  // Refs rather than state: state captured in a callback's closure would
  // still read null right after that same callback loaded the model.
  const modelRef = useRef(null);
  const processorRef = useRef(null);
  // Promise of the load currently in progress, or null when none is running.
  const pendingLoadRef = useRef(null);

  const loadModel = useCallback(async () => {
    if (modelRef.current && processorRef.current) return true;

    // Download the model and processor once; resolves to success as a boolean.
    const startLoad = async () => {
      setLoading(true);
      setProgress(0);
      try {
        const loadedModel = await AutoModel.from_pretrained(MODEL_ID, {
          progress_callback: (data) => {
            if (data.status === 'progress') {
              setProgress(Math.round(data.progress));
            }
          },
          local_model_path: LOCAL_MODEL_PATH,
        });
        const loadedProcessor = await AutoProcessor.from_pretrained(MODEL_ID, {
          local_model_path: LOCAL_MODEL_PATH,
        });
        modelRef.current = loadedModel;
        processorRef.current = loadedProcessor;
        setHasModel(true);
        return true;
      } catch (error) {
        console.error('Failed to load background removal model:', error);
        return false;
      } finally {
        setLoading(false);
      }
    };

    if (!pendingLoadRef.current) {
      // `.finally` runs asynchronously, i.e. after this assignment, so the
      // slot is always cleared. A failed load can therefore be retried.
      pendingLoadRef.current = startLoad().finally(() => {
        pendingLoadRef.current = null;
      });
    }
    return pendingLoadRef.current;
  }, []);

  const removeBackground = useCallback(async (imageSrc) => {
    // Loads the model on first use, or joins a load that is already running.
    if (!(await loadModel())) return null;

    setLoading(true);
    try {
      const img = await loadImage(imageSrc);
      const { width, height } = img;

      const canvas = document.createElement('canvas');
      canvas.width = width;
      canvas.height = height;
      const ctx = canvas.getContext('2d');
      ctx.drawImage(img, 0, 0);
      const imageData = ctx.getImageData(0, 0, width, height);
      const pixels = imageData.data;

      // The processor takes a RawImage, not an HTMLImageElement. Give it a
      // copy of the RGBA pixels, because `pixels` is edited in place below.
      const rawImage = new RawImage(pixels.slice(), width, height, 4);
      const { pixel_values } = await processorRef.current(rawImage);

      // RMBG-1.4's ONNX graph names its image input `input`. `pixel_values`
      // is also supplied in case a local export uses that name.
      // NOTE(review): this relies on Transformers.js ignoring feeds the
      // session does not declare; confirm against the pinned version.
      const { output } = await modelRef.current({ input: pixel_values, pixel_values });

      // Presumably the matte is in [0, 1]: scale it to bytes, then resize it
      // back to the source image size.
      const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
        width,
        height,
      );
      // Read one mask value per pixel even if the mask has several channels.
      const maskStride = mask.channels ?? 1;
      for (let i = 0; i < width * height; i++) {
        pixels[i * 4 + 3] = mask.data[i * maskStride];
      }
      ctx.putImageData(imageData, 0, 0);

      return URL.createObjectURL(await canvasToPngBlob(canvas));
    } catch (error) {
      console.error('Background removal failed:', error);
      return null;
    } finally {
      setLoading(false);
    }
  }, [loadModel]);

  return {
    loading,
    progress,
    hasModel,
    loadModel,
    removeBackground,
  };
}