- Added @xenova/transformers dependency
- useBackgroundRemoval hook with RMBG-1.4 model
- Client-side background removal with progress indicator
- Background removal button in properties panel (image elements only)
- ~170MB model cached after first download

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
121 lines
3.2 KiB
JavaScript
121 lines
3.2 KiB
JavaScript
import { useCallback, useRef, useState } from 'react';
import { AutoModel, AutoProcessor, RawImage, env } from '@xenova/transformers';
|
|
|
|
// Use local models only
|
|
env.allowLocalModels = true;
|
|
env.useBrowserCache = true;
|
|
|
|
/**
 * React hook for client-side background removal using the RMBG-1.4 model
 * (via @xenova/transformers, running fully in the browser).
 *
 * @returns {{
 *   loading: boolean,            // true while the model loads or inference runs
 *   progress: number,            // 0-100 model download progress
 *   hasModel: boolean,           // true once the model is loaded
 *   loadModel: () => Promise<boolean>,
 *   removeBackground: (imageSrc: string) => Promise<string|null>,
 * }}
 */
export function useBackgroundRemoval() {
  const [loading, setLoading] = useState(false);
  const [progress, setProgress] = useState(0);
  const [model, setModel] = useState(null);
  const [processor, setProcessor] = useState(null);

  // Refs mirror the model/processor state. React state updates are async, so
  // a callback that just awaited loadModel() would still see the stale null
  // `model`/`processor` captured in its closure — the original code crashed on
  // the first removeBackground() call for exactly this reason. Refs are
  // readable synchronously, fixing that stale-closure bug; state is kept only
  // to drive `hasModel` and re-renders.
  const modelRef = useRef(null);
  const processorRef = useRef(null);

  /**
   * Lazily download and initialize the RMBG-1.4 model and its processor.
   * Idempotent: resolves immediately once already loaded.
   * @returns {Promise<boolean>} true when ready, false if loading failed.
   */
  const loadModel = useCallback(async () => {
    if (modelRef.current && processorRef.current) return true;

    setLoading(true);
    setProgress(0);

    try {
      const loadedModel = await AutoModel.from_pretrained('Xenova/rmbg-1.4', {
        progress_callback: (data) => {
          if (data.status === 'progress') {
            setProgress(Math.round(data.progress));
          }
        },
        // NOTE(review): local_model_path is not a documented from_pretrained
        // option in @xenova/transformers (env.localModelPath is the global
        // setting) — verify local loading actually works as intended.
        local_model_path: '/models/rmbg-1.4',
      });

      const loadedProcessor = await AutoProcessor.from_pretrained('Xenova/rmbg-1.4', {
        local_model_path: '/models/rmbg-1.4',
      });

      modelRef.current = loadedModel;
      processorRef.current = loadedProcessor;
      setModel(loadedModel);
      setProcessor(loadedProcessor);
      return true;
    } catch (error) {
      console.error('Failed to load background removal model:', error);
      return false;
    } finally {
      setLoading(false);
    }
  }, []);

  /**
   * Remove the background from an image, returning a transparent PNG.
   *
   * @param {string} imageSrc - URL or data URI of the source image.
   * @returns {Promise<string|null>} blob URL of the result, or null on
   *   failure. The caller owns the URL and should URL.revokeObjectURL() it
   *   when no longer needed (it is never revoked here).
   */
  const removeBackground = useCallback(async (imageSrc) => {
    if (!modelRef.current || !processorRef.current) {
      const loaded = await loadModel();
      if (!loaded) return null;
    }
    // Read through the refs so we get the instances loaded above even on the
    // very first call, before React has re-rendered with the new state.
    const activeModel = modelRef.current;
    const activeProcessor = processorRef.current;

    setLoading(true);

    try {
      // Load the image. Handlers are attached before src is assigned so a
      // synchronously-cached image cannot fire load before we listen.
      const img = new Image();
      img.crossOrigin = 'anonymous';
      await new Promise((resolve, reject) => {
        img.onload = resolve;
        img.onerror = reject;
        img.src = imageSrc;
      });

      // Preprocess and run inference.
      const { pixel_values } = await activeProcessor(img);
      const { output } = await activeModel({ pixel_values });

      // Convert the model output to a byte mask at the original resolution.
      const maskData = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(
        img.width,
        img.height
      );

      // Composite: draw the original image, then copy the mask into the
      // alpha channel so background pixels become transparent.
      const canvas = document.createElement('canvas');
      canvas.width = img.width;
      canvas.height = img.height;
      const ctx = canvas.getContext('2d');
      ctx.drawImage(img, 0, 0);

      const imageData = ctx.getImageData(0, 0, img.width, img.height);
      const data = imageData.data;
      // Assumes a single-channel mask (one byte per pixel) — TODO confirm
      // against RMBG-1.4 output shape.
      const maskPixels = maskData.data;

      for (let i = 0; i < maskPixels.length; i++) {
        data[i * 4 + 3] = maskPixels[i]; // set alpha channel
      }

      ctx.putImageData(imageData, 0, 0);

      // Export as PNG (the only listed format guaranteed to keep alpha).
      const blob = await new Promise((resolve) => {
        canvas.toBlob(resolve, 'image/png');
      });
      if (!blob) {
        // toBlob yields null on failure (e.g. oversized canvas); fail cleanly
        // instead of letting URL.createObjectURL(null) throw.
        throw new Error('canvas.toBlob returned null');
      }

      return URL.createObjectURL(blob);
    } catch (error) {
      console.error('Background removal failed:', error);
      return null;
    } finally {
      setLoading(false);
    }
  }, [loadModel]);

  return {
    loading,
    progress,
    hasModel: !!model,
    loadModel,
    removeBackground,
  };
}
|