diff --git a/client/package.json b/client/package.json index 30b0b2b..08a7785 100644 --- a/client/package.json +++ b/client/package.json @@ -14,7 +14,8 @@ "react-dom": "^19.2.5", "react-konva": "^18.2.10", "konva": "^9.3.18", - "use-image": "^1.1.1" + "use-image": "^1.1.1", + "@xenova/transformers": "^2.17.2" }, "devDependencies": { "@eslint/js": "^9.39.4", diff --git a/client/src/components/properties/PropertiesPanel.jsx b/client/src/components/properties/PropertiesPanel.jsx index 7b9140e..c9b46fb 100644 --- a/client/src/components/properties/PropertiesPanel.jsx +++ b/client/src/components/properties/PropertiesPanel.jsx @@ -1,3 +1,5 @@ +import { BackgroundRemovalButton } from '../sidebar/BackgroundRemovalButton'; + export function PropertiesPanel({ selectedElement, onUpdate, onDelete }) { if (!selectedElement) { return ( @@ -101,6 +103,13 @@ export function PropertiesPanel({ selectedElement, onUpdate, onDelete }) { /> + {selectedElement.type === 'image' && ( + + )} + diff --git a/client/src/components/sidebar/BackgroundRemovalButton.jsx b/client/src/components/sidebar/BackgroundRemovalButton.jsx new file mode 100644 index 0000000..0f93cf7 --- /dev/null +++ b/client/src/components/sidebar/BackgroundRemovalButton.jsx @@ -0,0 +1,47 @@ +import { useBackgroundRemoval } from '../../hooks/useBackgroundRemoval'; + +export function BackgroundRemovalButton({ selectedElement, onUpdate }) { + const { loading, progress, hasModel, loadModel, removeBackground } = useBackgroundRemoval(); + + const handleRemoveBackground = async () => { + if (!selectedElement || selectedElement.type !== 'image') return; + + if (!hasModel) { + const loaded = await loadModel(); + if (!loaded) return; + } + + const resultUrl = await removeBackground(selectedElement.src); + if (resultUrl) { + onUpdate(selectedElement.id, { src: resultUrl }); + } + }; + + if (!selectedElement || selectedElement.type !== 'image') { + return null; + } + + return ( +
+ + {!hasModel && ( +

+ First use requires downloading ~170MB model. Subsequent uses are cached. +

+ )} +
+ ); +} diff --git a/client/src/hooks/useBackgroundRemoval.js b/client/src/hooks/useBackgroundRemoval.js new file mode 100644 index 0000000..5d2cd77 --- /dev/null +++ b/client/src/hooks/useBackgroundRemoval.js @@ -0,0 +1,120 @@ +import { useState, useCallback } from 'react'; +import { env, AutoModel, AutoProcessor, RawImage } from '@xenova/transformers'; + +// Use local models only +env.allowLocalModels = true; +env.useBrowserCache = true; + +export function useBackgroundRemoval() { + const [loading, setLoading] = useState(false); + const [progress, setProgress] = useState(0); + const [model, setModel] = useState(null); + const [processor, setProcessor] = useState(null); + + const loadModel = useCallback(async () => { + if (model && processor) return true; + + setLoading(true); + setProgress(0); + + try { + const loadedModel = await AutoModel.from_pretrained('Xenova/rmbg-1.4', { + progress_callback: (data) => { + if (data.status === 'progress') { + setProgress(Math.round(data.progress)); + } + }, + local_model_path: '/models/rmbg-1.4', + }); + + const loadedProcessor = await AutoProcessor.from_pretrained('Xenova/rmbg-1.4', { + local_model_path: '/models/rmbg-1.4', + }); + + setModel(loadedModel); + setProcessor(loadedProcessor); + setLoading(false); + return true; + } catch (error) { + console.error('Failed to load background removal model:', error); + setLoading(false); + return false; + } + }, [model, processor]); + + const removeBackground = useCallback(async (imageSrc) => { + if (!model || !processor) { + const loaded = await loadModel(); + if (!loaded) return null; + } + + setLoading(true); + + try { + // Load the image + const img = new Image(); + img.crossOrigin = 'anonymous'; + img.src = imageSrc; + await new Promise((resolve, reject) => { + img.onload = resolve; + img.onerror = reject; + }); + + // Process image through the model + const inputs = await processor(img); + const { pixel_values } = inputs; + + // Run inference + const { output } = await model({ pixel_values }); + + // Get the mask + const maskData = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize( + img.width, + img.height + ); + + // Create canvas to apply mask + const canvas = document.createElement('canvas'); + canvas.width = img.width; + canvas.height = img.height; + const ctx = canvas.getContext('2d'); + + // Draw original image + ctx.drawImage(img, 0, 0); + + // Get image data + const imageData = ctx.getImageData(0, 0, img.width, img.height); + const data = imageData.data; + const maskPixels = maskData.data; + + // Apply alpha mask + for (let i = 0; i < maskPixels.length; i++) { + const alpha = maskPixels[i]; + data[i * 4 + 3] = alpha; // Set alpha channel + } + + ctx.putImageData(imageData, 0, 0); + + // Convert to blob URL + const blob = await new Promise((resolve) => { + canvas.toBlob(resolve, 'image/png'); + }); + + const url = URL.createObjectURL(blob); + setLoading(false); + return url; + } catch (error) { + console.error('Background removal failed:', error); + setLoading(false); + return null; + } + }, [model, processor, loadModel]); + + return { + loading, + progress, + hasModel: !!model, + loadModel, + removeBackground, + }; +} diff --git a/client/src/index.css b/client/src/index.css index 1790f68..e0f5827 100644 --- a/client/src/index.css +++ b/client/src/index.css @@ -490,6 +490,54 @@ input, textarea, select { color: white; } +.bg-removal-container { + margin-top: 1rem; + padding-top: 1rem; + border-top: 1px solid var(--border); +} + +.bg-removal-btn { + width: 100%; + padding: 0.75rem; + background: linear-gradient(135deg, #8b5cf6, #ec4899); + color: white; + border: none; + border-radius: var(--radius-md); + font-weight: 500; + cursor: pointer; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; + gap: 0.5rem; +} + +.bg-removal-btn:hover:not(:disabled) { + transform: translateY(-1px); + box-shadow: var(--shadow-md); +} + +.bg-removal-btn:disabled { + opacity: 0.7; + cursor: not-allowed; +} + +.bg-removal-hint { + font-size: 0.7rem; + color: var(--text-muted); + margin: 0.5rem 0 0 0; + line-height: 1.4; +} + +.spinner-small { + width: 16px; + height: 16px; + border: 2px solid rgba(255, 255, 255, 0.3); + border-top-color: white; + border-radius: 50%; + animation: spin 1s linear infinite; +} + /* Responsive */ @media (max-width: 900px) { .app-layout {