"use client"

import { useState, useEffect, useRef } from "react"
import { Button } from "@/components/ui/button"
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
import { Badge } from "@/components/ui/badge"
import { Progress } from "@/components/ui/progress"
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"
import { Slider } from "@/components/ui/slider"
import { Label } from "@/components/ui/label"
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"
import { Play, Pause, Square, Zap, Info } from "lucide-react"

interface ModelTrainingStepProps {
  onNext?: () => void
}

interface TrainingConfig {
  learningRate: number
  batchSize: number
  epochs: number
  optimizer: string
  architecture: string
}

interface TrainingMetrics {
  epoch: number
  loss: number
  accuracy: number
  valLoss: number
  valAccuracy: number
  learningRate: number
}

interface TrainingState {
  isTraining: boolean
  isPaused: boolean
  currentEpoch: number
  totalEpochs: number
  progress: number
  metrics: TrainingMetrics[]
  phase: "idle" | "initializing" | "training" | "validating" | "completed"
}

/** Delay between simulated epochs; kept short so the demo run finishes quickly. */
const EPOCH_INTERVAL_MS = 200

/** State before any run starts, and the state restored by `stopTraining`. */
const initialTrainingState: TrainingState = {
  isTraining: false,
  isPaused: false,
  currentEpoch: 0,
  totalEpochs: 0,
  progress: 0,
  metrics: [],
  phase: "idle",
}

const architectures = [
  { value: "cnn", label: "CNN (卷积神经网络)", description: "适合图像特征提取" },
  { value: "resnet", label: "ResNet", description: "深度残差网络,防止梯度消失" },
  { value: "vgg", label: "VGG", description: "经典深度卷积网络" },
  { value: "gan", label: "GAN", description: "生成对抗网络,用于图像生成" },
]

const optimizers = [
  { value: "adam", label: "Adam", description: "自适应学习率优化器" },
  { value: "sgd", label: "SGD", description: "随机梯度下降" },
  { value: "rmsprop", label: "RMSprop", description: "均方根传播算法" },
]

/**
 * Tutorial step that lets the user tune hyperparameters and watch a *simulated*
 * training run. No model is trained: per-epoch metrics are synthesised by
 * `generateMetrics` on a timer.
 *
 * @param onNext - called by `handleNext`, but only once the run has completed.
 */
export default function ModelTrainingStep({ onNext }: ModelTrainingStepProps) {
  const [config, setConfig] = useState<TrainingConfig>({
    learningRate: 0.001,
    batchSize: 32,
    epochs: 50,
    optimizer: "adam",
    architecture: "cnn",
  })
  const [trainingState, setTrainingState] = useState<TrainingState>(initialTrainingState)
  // Handle of the active simulation timer, or null when no timer is running.
  const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null)

  /**
   * Synthesises plausible metrics for a 1-based `epoch` of the current config.
   * Loss decays exponentially towards 0.1 and accuracy rises from 0.05, both
   * with random noise. Validation values are derived to be slightly worse than
   * the training values.
   */
  const generateMetrics = (epoch: number): TrainingMetrics => {
    // Fraction of the run completed, drives the shape of both curves.
    const progress = epoch / config.epochs
    const baseLoss = 2.5 * Math.exp(-progress * 3) + 0.1
    const baseAccuracy = 0.95 * (1 - Math.exp(-progress * 4)) + 0.05

    // Add some random jitter so the curves don't look perfectly smooth.
    const lossNoise = (Math.random() - 0.5) * 0.1
    const accNoise = (Math.random() - 0.5) * 0.05

    return {
      epoch,
      loss: Math.max(0.05, baseLoss + lossNoise),
      accuracy: Math.min(0.98, Math.max(0.05, baseAccuracy + accNoise)),
      valLoss: Math.max(0.08, baseLoss * 1.1 + lossNoise),
      valAccuracy: Math.min(0.95, Math.max(0.05, baseAccuracy * 0.95 + accNoise)),
      // Step decay: learning rate is multiplied by 0.95 every 10 epochs.
      learningRate: config.learningRate * Math.pow(0.95, Math.floor(epoch / 10)),
    }
  }

  /** Stops the simulation timer (if any) and forgets its handle. */
  const clearTrainingInterval = () => {
    if (intervalRef.current !== null) {
      clearInterval(intervalRef.current)
      intervalRef.current = null
    }
  }

  /**
   * Runs the simulated epoch loop, one epoch per EPOCH_INTERVAL_MS, continuing
   * after `completedEpochs` already-finished epochs. Any running timer is
   * cleared first so two loops can never run concurrently.
   */
  const runEpochs = (completedEpochs: number) => {
    clearTrainingInterval()
    const totalEpochs = config.epochs
    let epoch = completedEpochs

    intervalRef.current = setInterval(() => {
      epoch += 1
      // Snapshot per tick: the state updater below may run after `epoch` has
      // advanced again, so it must not read the mutable counter.
      const finishedEpoch = epoch
      const metrics = generateMetrics(finishedEpoch)
      const finished = finishedEpoch >= totalEpochs

      setTrainingState((prev) => ({
        ...prev,
        isTraining: !finished,
        currentEpoch: finishedEpoch,
        progress: (finishedEpoch / totalEpochs) * 100,
        metrics: [...prev.metrics, metrics],
        // Every 5th epoch is labelled as a validation pass.
        phase: finished ? "completed" : finishedEpoch % 5 === 0 ? "validating" : "training",
      }))

      // Finish on the last epoch's tick instead of spending an extra idle tick.
      if (finished) clearTrainingInterval()
    }, EPOCH_INTERVAL_MS)
  }

  /** Resets progress and starts a fresh simulated run with the current config. */
  const startTraining = () => {
    setTrainingState({
      ...initialTrainingState,
      isTraining: true,
      totalEpochs: config.epochs,
      phase: "initializing",
    })
    runEpochs(0)
  }

  /**
   * Pauses the running simulation. When already paused, resumes it from the
   * last completed epoch; nothing else in this component restarts the timer
   * for a paused run.
   */
  const pauseTraining = () => {
    if (trainingState.isPaused) {
      setTrainingState((prev) => ({ ...prev, isPaused: false }))
      runEpochs(trainingState.currentEpoch)
      return
    }
    clearTrainingInterval()
    setTrainingState((prev) => ({ ...prev, isPaused: true }))
  }

  /** Aborts the run and discards all collected metrics. */
  const stopTraining = () => {
    clearTrainingInterval()
    setTrainingState(initialTrainingState)
  }

  /**
   * Updates one hyperparameter. Ignored while a run is in progress (including
   * while paused) and when `value` is undefined, e.g. when destructured from an
   * empty slider value array.
   */
  function updateConfig<K extends keyof TrainingConfig>(key: K, value: TrainingConfig[K] | undefined) {
    if (trainingState.isTraining || value === undefined) return
    setConfig((prev) => ({ ...prev, [key]: value }))
  }

  // Latest epoch's metrics, or undefined before the first epoch completes.
  const currentMetrics = trainingState.metrics[trainingState.metrics.length - 1]
  // Upper bound of the loss chart's y-scale; floored at 2.5 so the scale does
  // not jump around while only a few points exist.
  const maxLoss = Math.max(...trainingState.metrics.map((m) => Math.max(m.loss, m.valLoss)), 2.5)
  const maxAccuracy = 1.0

  // Stop the simulation timer when the component unmounts.
  useEffect(() => {
    return () => {
      if (intervalRef.current !== null) clearInterval(intervalRef.current)
    }
  }, [])

  /** Advances to the next tutorial step, but only after training has completed. */
  const handleNext = () => {
    if (onNext && trainingState.phase === "completed") {
      onNext()
    }
  }

  // NOTE(review): in this copy of the file the JSX returned below appears to
  // have lost its element tags; only text nodes and expression fragments (e.g.
  // `updateConfig("learningRate", value)} min={0.0001} ...`) remain. Restore
  // the markup from version control. The handlers and derived values above keep
  // the names that markup references.
  return (
{/* Introduction */}
什么是模型训练?

模型训练是AI学习的核心过程。就像学生通过大量练习来掌握知识一样,AI模型通过反复处理训练数据来学习图像的特征和模式。 在这个过程中,模型会不断调整内部参数,逐渐提高对图像的理解能力,最终能够生成新的艺术作品。

{/* Training Configuration */} 训练配置 调整训练参数,观察对模型性能的影响 {/* Architecture Selection */}
{/* Optimizer Selection */}
{/* Learning Rate */}
updateConfig("learningRate", value)} min={0.0001} max={0.01} step={0.0001} disabled={trainingState.isTraining} className="w-full" />

控制模型参数更新的步长

{/* Batch Size */}
updateConfig("batchSize", value)} min={8} max={128} step={8} disabled={trainingState.isTraining} className="w-full" />

每次训练处理的图像数量

{/* Epochs */}
updateConfig("epochs", value)} min={10} max={200} step={10} disabled={trainingState.isTraining} className="w-full" />

完整遍历训练数据的次数

{/* Training Controls */}
{!trainingState.isTraining ? ( ) : ( <> )}
{/* Training Progress and Metrics */}
{/* Training Status */}
训练状态 {trainingState.phase === "idle" ? "待机中" : trainingState.phase === "initializing" ? "初始化中" : trainingState.phase === "training" ? "训练中" : trainingState.phase === "validating" ? "验证中" : "训练完成"}
{/* Overall Progress */}
总体进度 {trainingState.currentEpoch}/{trainingState.totalEpochs} 轮
{/* Current Metrics */} {currentMetrics && (
{currentMetrics.loss.toFixed(4)}
训练损失
{(currentMetrics.accuracy * 100).toFixed(1)}%
训练准确率
{currentMetrics.valLoss.toFixed(4)}
验证损失
{(currentMetrics.valAccuracy * 100).toFixed(1)}%
验证准确率
)}
{/* Training Charts */} 训练曲线 实时监控模型训练过程中的性能指标 损失函数 准确率
{/* Grid lines */} {/* Loss curves */} {trainingState.metrics.length > 1 && ( <> {/* Training loss */} `${(i / Math.max(trainingState.metrics.length - 1, 1)) * 400},${200 - (m.loss / maxLoss) * 180}`, ) .join(" ")} /> {/* Validation loss */} `${(i / Math.max(trainingState.metrics.length - 1, 1)) * 400},${200 - (m.valLoss / maxLoss) * 180}`, ) .join(" ")} /> )}
训练损失
验证损失
{/* Accuracy curves */} {trainingState.metrics.length > 1 && ( <> {/* Training accuracy */} `${(i / Math.max(trainingState.metrics.length - 1, 1)) * 400},${200 - m.accuracy * 180}`, ) .join(" ")} /> {/* Validation accuracy */} `${(i / Math.max(trainingState.metrics.length - 1, 1)) * 400},${200 - m.valAccuracy * 180}`, ) .join(" ")} /> )}
训练准确率
验证准确率
{/* Next Step Button */}
) }