||
- "use client"
- import { useState, useEffect, useRef } from "react"
- import { Button } from "@/components/ui/button"
- import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"
- import { Badge } from "@/components/ui/badge"
- import { Progress } from "@/components/ui/progress"
- import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"
- import { Slider } from "@/components/ui/slider"
- import { Label } from "@/components/ui/label"
- import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"
- import { Play, Pause, Square, Zap, Info } from "lucide-react"
/** Props accepted by the model-training tutorial step. */
interface ModelTrainingStepProps {
  /**
   * Called when the user clicks the "next step" button. The button is only
   * enabled once simulated training has reached the "completed" phase.
   */
  onNext?: () => void
}
/** User-adjustable hyper-parameters for the simulated training run. */
interface TrainingConfig {
  /** Selected model architecture; one of the `value`s listed in `architectures`. */
  architecture: string
  /** Selected optimizer; one of the `value`s listed in `optimizers`. */
  optimizer: string
  /** Initial learning rate chosen on the slider. */
  learningRate: number
  /** Number of images per training batch, as chosen on the slider. */
  batchSize: number
  /** Total number of epochs the simulation will run. */
  epochs: number
}
/** One synthetic metrics sample recorded at the end of an epoch. */
interface TrainingMetrics {
  /** Epoch this sample belongs to (the simulation numbers epochs from 1). */
  epoch: number
  /** Learning rate in effect for this epoch. */
  learningRate: number

  /** Training-set loss. */
  loss: number
  /** Validation-set loss. */
  valLoss: number

  /** Training-set accuracy as a fraction (rendered as a percentage). */
  accuracy: number
  /** Validation-set accuracy as a fraction (rendered as a percentage). */
  valAccuracy: number
}
/** Everything the UI needs to render the progress of the simulated training run. */
interface TrainingState {
  /** Current stage of the run, used for the status badge and the "next" button. */
  phase: "idle" | "initializing" | "training" | "validating" | "completed"
  /** True from "start" until the run completes or is stopped. */
  isTraining: boolean
  /** True while a running session has been paused by the user. */
  isPaused: boolean

  /** Number of epochs finished so far. */
  currentEpoch: number
  /** Epoch count fixed when the run was started. */
  totalEpochs: number
  /** Completion percentage in the range 0-100. */
  progress: number

  /** Per-epoch samples, oldest first; drives the status tiles and the charts. */
  metrics: TrainingMetrics[]
}
/**
 * Model architectures offered in the "模型架构" selector.
 * `value` is stored in `TrainingConfig.architecture`; `label` and
 * `description` are shown in the dropdown.
 */
const architectures = [
  {
    value: "cnn",
    label: "CNN (卷积神经网络)",
    description: "适合图像特征提取",
  },
  {
    value: "resnet",
    label: "ResNet",
    description: "深度残差网络,防止梯度消失",
  },
  {
    value: "vgg",
    label: "VGG",
    description: "经典深度卷积网络",
  },
  {
    value: "gan",
    label: "GAN",
    description: "生成对抗网络,用于图像生成",
  },
]
/**
 * Optimizers offered in the "优化器" selector.
 * `value` is stored in `TrainingConfig.optimizer`; `label` and
 * `description` are shown in the dropdown.
 */
const optimizers = [
  {
    value: "adam",
    label: "Adam",
    description: "自适应学习率优化器",
  },
  {
    value: "sgd",
    label: "SGD",
    description: "随机梯度下降",
  },
  {
    value: "rmsprop",
    label: "RMSprop",
    description: "均方根传播算法",
  },
]
/** Delay between simulated epochs; kept short so the demo finishes quickly. */
const EPOCH_INTERVAL_MS = 200

/** State before training starts and after "停止". Treated as immutable. */
const idleTrainingState: TrainingState = {
  isTraining: false,
  isPaused: false,
  currentEpoch: 0,
  totalEpochs: 0,
  progress: 0,
  metrics: [],
  phase: "idle",
}

/** Status-badge text for each training phase. */
const phaseLabels: Record<TrainingState["phase"], string> = {
  idle: "待机中",
  initializing: "初始化中",
  training: "训练中",
  validating: "验证中",
  completed: "训练完成",
}

/** Status-badge variant for each training phase. */
const phaseBadgeVariants: Record<TrainingState["phase"], "default" | "secondary" | "outline"> = {
  idle: "secondary",
  initializing: "secondary",
  training: "secondary",
  validating: "outline",
  completed: "default",
}

/**
 * Converts a metric series into an SVG `points` string for a 400x200 viewBox.
 * X spreads samples evenly over the full width; Y maps `value / maxValue` onto
 * the bottom 180 units so a curve at `maxValue` stays clear of the top edge.
 */
function toPolylinePoints(values: number[], maxValue: number): string {
  const lastIndex = Math.max(values.length - 1, 1)
  return values.map((value, i) => `${(i / lastIndex) * 400},${200 - (value / maxValue) * 180}`).join(" ")
}

/**
 * Faint background grid for a chart SVG. Each chart must pass its own
 * `patternId`: the tabs unmount inactive panels, so a pattern defined in one
 * chart's SVG is not available to the other. These ids are static, so two
 * instances of this component on one page would still share them.
 */
function ChartGrid({ patternId }: { patternId: string }) {
  return (
    <>
      <defs>
        <pattern id={patternId} width="40" height="20" patternUnits="userSpaceOnUse">
          <path d="M 40 0 L 0 0 0 20" fill="none" stroke="currentColor" strokeWidth="0.5" opacity="0.1" />
        </pattern>
      </defs>
      <rect width="100%" height="100%" fill={`url(#${patternId})`} />
    </>
  )
}

/** One value/label tile in the "current metrics" grid. */
function MetricTile({ value, label, valueClassName }: { value: string; label: string; valueClassName: string }) {
  return (
    <div className="text-center p-3 bg-muted rounded-lg">
      <div className={`text-lg font-bold ${valueClassName}`}>{value}</div>
      <div className="text-xs text-muted-foreground">{label}</div>
    </div>
  )
}

/**
 * Interactive "model training" step of the tutorial.
 *
 * The user picks an architecture, optimizer, learning rate, batch size and
 * epoch count. The component then runs a simulated training loop (no real
 * model is trained): every EPOCH_INTERVAL_MS a synthetic metrics sample is
 * appended and shown as status tiles and SVG loss/accuracy curves. The run can
 * be paused, resumed, or stopped; stopping resets all progress.
 *
 * @param onNext - Invoked by the "继续到图像生成" button, which is only enabled
 *   after the run reaches the "completed" phase.
 */
export default function ModelTrainingStep({ onNext }: ModelTrainingStepProps) {
  const [config, setConfig] = useState<TrainingConfig>({
    learningRate: 0.001,
    batchSize: 32,
    epochs: 50,
    optimizer: "adam",
    architecture: "cnn",
  })
  const [trainingState, setTrainingState] = useState<TrainingState>(idleTrainingState)
  // Handle of the running simulation timer, or null when none is active.
  const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null)
  // Epochs completed so far. Kept in a ref (not a closure local) so a resumed
  // timer continues from where the paused one stopped.
  const completedEpochsRef = useRef(0)

  const clearTimer = () => {
    if (intervalRef.current !== null) {
      clearInterval(intervalRef.current)
      intervalRef.current = null
    }
  }

  /** Synthesises plausible metrics for the 1-based `epoch` under the current config. */
  const generateMetrics = (epoch: number): TrainingMetrics => {
    // Exponential decay / saturation curves imitate a typical training run.
    const progress = epoch / config.epochs
    const baseLoss = 2.5 * Math.exp(-progress * 3) + 0.1
    const baseAccuracy = 0.95 * (1 - Math.exp(-progress * 4)) + 0.05
    // Small random jitter so the curves look less synthetic.
    const lossNoise = (Math.random() - 0.5) * 0.1
    const accNoise = (Math.random() - 0.5) * 0.05
    return {
      epoch,
      loss: Math.max(0.05, baseLoss + lossNoise),
      accuracy: Math.min(0.98, Math.max(0.05, baseAccuracy + accNoise)),
      valLoss: Math.max(0.08, baseLoss * 1.1 + lossNoise),
      valAccuracy: Math.min(0.95, Math.max(0.05, baseAccuracy * 0.95 + accNoise)),
      // Step decay: multiply by 0.95 every 10 epochs.
      learningRate: config.learningRate * Math.pow(0.95, Math.floor(epoch / 10)),
    }
  }

  /** (Re)starts the epoch timer, continuing from completedEpochsRef. */
  const runTimer = () => {
    clearTimer() // never let two timers run at once
    intervalRef.current = setInterval(() => {
      const epoch = completedEpochsRef.current + 1
      const sample = generateMetrics(epoch)
      const finished = epoch >= config.epochs
      completedEpochsRef.current = epoch
      // Stop on the final epoch itself instead of wasting one more tick.
      if (finished) clearTimer()
      setTrainingState((prev) => ({
        ...prev,
        isTraining: !finished,
        currentEpoch: epoch,
        progress: (epoch / config.epochs) * 100,
        metrics: [...prev.metrics, sample],
        // Every 5th epoch is presented as a validation pass.
        phase: finished ? "completed" : epoch % 5 === 0 ? "validating" : "training",
      }))
    }, EPOCH_INTERVAL_MS)
  }

  const startTraining = () => {
    completedEpochsRef.current = 0
    setTrainingState({
      isTraining: true,
      isPaused: false,
      currentEpoch: 0,
      totalEpochs: config.epochs,
      progress: 0,
      metrics: [],
      phase: "initializing",
    })
    runTimer()
  }

  const pauseTraining = () => {
    clearTimer()
    setTrainingState((prev) => ({ ...prev, isPaused: true }))
  }

  const resumeTraining = () => {
    setTrainingState((prev) => ({ ...prev, isPaused: false }))
    runTimer()
  }

  const stopTraining = () => {
    clearTimer()
    completedEpochsRef.current = 0
    setTrainingState(idleTrainingState)
  }

  /** Updates one config field; ignored while a run (running or paused) is in progress. */
  const updateConfig = <K extends keyof TrainingConfig>(key: K, value: TrainingConfig[K]) => {
    if (!trainingState.isTraining) {
      setConfig((prev) => ({ ...prev, [key]: value }))
    }
  }

  /** Builds a Slider `onValueChange` handler for a numeric config field. */
  const handleSliderChange =
    (key: "learningRate" | "batchSize" | "epochs") =>
    ([value]: number[]) => {
      if (value !== undefined) updateConfig(key, value)
    }

  useEffect(() => {
    // Make sure no timer keeps firing after unmount.
    return () => {
      if (intervalRef.current !== null) clearInterval(intervalRef.current)
    }
  }, [])

  const handleNext = () => {
    if (onNext && trainingState.phase === "completed") {
      onNext()
    }
  }

  const history = trainingState.metrics
  const currentMetrics = history[history.length - 1]
  // Loss axis is scaled to at least 2.5 so early curves do not fill the chart.
  const maxLoss = Math.max(...history.map((m) => Math.max(m.loss, m.valLoss)), 2.5)
  const maxAccuracy = 1.0
  const hasCurve = history.length > 1

  return (
    <div className="space-y-6">
      {/* Introduction */}
      <Card className="bg-chart-3/5 border-chart-3/20">
        <CardHeader>
          <div className="flex items-center gap-2">
            <Info className="w-5 h-5 text-chart-3" />
            <CardTitle className="text-lg font-serif">什么是模型训练?</CardTitle>
          </div>
        </CardHeader>
        <CardContent>
          <p className="text-muted-foreground leading-relaxed">
            模型训练是AI学习的核心过程。就像学生通过大量练习来掌握知识一样,AI模型通过反复处理训练数据来学习图像的特征和模式。
            在这个过程中,模型会不断调整内部参数,逐渐提高对图像的理解能力,最终能够生成新的艺术作品。
          </p>
        </CardContent>
      </Card>

      <div className="grid lg:grid-cols-3 gap-6">
        {/* Training configuration */}
        <Card className="lg:col-span-1">
          <CardHeader>
            <CardTitle className="font-serif">训练配置</CardTitle>
            <CardDescription>调整训练参数,观察对模型性能的影响</CardDescription>
          </CardHeader>
          <CardContent className="space-y-6">
            {/* Architecture selection */}
            <div className="space-y-2">
              <Label className="font-medium">模型架构</Label>
              <Select
                value={config.architecture}
                onValueChange={(value) => updateConfig("architecture", value)}
                disabled={trainingState.isTraining}
              >
                <SelectTrigger>
                  <SelectValue />
                </SelectTrigger>
                <SelectContent>
                  {architectures.map((arch) => (
                    <SelectItem key={arch.value} value={arch.value}>
                      <div>
                        <div className="font-medium">{arch.label}</div>
                        <div className="text-xs text-muted-foreground">{arch.description}</div>
                      </div>
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>

            {/* Optimizer selection */}
            <div className="space-y-2">
              <Label className="font-medium">优化器</Label>
              <Select
                value={config.optimizer}
                onValueChange={(value) => updateConfig("optimizer", value)}
                disabled={trainingState.isTraining}
              >
                <SelectTrigger>
                  <SelectValue />
                </SelectTrigger>
                <SelectContent>
                  {optimizers.map((opt) => (
                    <SelectItem key={opt.value} value={opt.value}>
                      <div>
                        <div className="font-medium">{opt.label}</div>
                        <div className="text-xs text-muted-foreground">{opt.description}</div>
                      </div>
                    </SelectItem>
                  ))}
                </SelectContent>
              </Select>
            </div>

            {/* Learning rate */}
            <div className="space-y-3">
              <Label className="font-medium">学习率: {config.learningRate}</Label>
              <Slider
                value={[config.learningRate]}
                onValueChange={handleSliderChange("learningRate")}
                min={0.0001}
                max={0.01}
                step={0.0001}
                disabled={trainingState.isTraining}
                className="w-full"
              />
              <p className="text-xs text-muted-foreground">控制模型参数更新的步长</p>
            </div>

            {/* Batch size */}
            <div className="space-y-3">
              <Label className="font-medium">批次大小: {config.batchSize}</Label>
              <Slider
                value={[config.batchSize]}
                onValueChange={handleSliderChange("batchSize")}
                min={8}
                max={128}
                step={8}
                disabled={trainingState.isTraining}
                className="w-full"
              />
              <p className="text-xs text-muted-foreground">每次训练处理的图像数量</p>
            </div>

            {/* Epochs */}
            <div className="space-y-3">
              <Label className="font-medium">训练轮数: {config.epochs}</Label>
              <Slider
                value={[config.epochs]}
                onValueChange={handleSliderChange("epochs")}
                min={10}
                max={200}
                step={10}
                disabled={trainingState.isTraining}
                className="w-full"
              />
              <p className="text-xs text-muted-foreground">完整遍历训练数据的次数</p>
            </div>

            {/* Training controls: start when idle; pause/resume + stop during a run */}
            <div className="flex gap-2 pt-4">
              {!trainingState.isTraining ? (
                <Button onClick={startTraining} className="flex-1">
                  <Play className="w-4 h-4 mr-2" />
                  开始训练
                </Button>
              ) : (
                <>
                  {trainingState.isPaused ? (
                    <Button onClick={resumeTraining} variant="outline" className="flex-1 bg-transparent">
                      <Play className="w-4 h-4 mr-2" />
                      继续
                    </Button>
                  ) : (
                    <Button onClick={pauseTraining} variant="outline" className="flex-1 bg-transparent">
                      <Pause className="w-4 h-4 mr-2" />
                      暂停
                    </Button>
                  )}
                  <Button onClick={stopTraining} variant="destructive" className="flex-1">
                    <Square className="w-4 h-4 mr-2" />
                    停止
                  </Button>
                </>
              )}
            </div>
          </CardContent>
        </Card>

        {/* Training progress and metrics */}
        <div className="lg:col-span-2 space-y-6">
          {/* Training status */}
          <Card>
            <CardHeader>
              <div className="flex items-center justify-between">
                <CardTitle className="font-serif">训练状态</CardTitle>
                <Badge variant={phaseBadgeVariants[trainingState.phase]}>
                  {trainingState.isPaused ? "已暂停" : phaseLabels[trainingState.phase]}
                </Badge>
              </div>
            </CardHeader>
            <CardContent>
              <div className="space-y-4">
                {/* Overall progress */}
                <div className="space-y-2">
                  <div className="flex justify-between text-sm">
                    <span>总体进度</span>
                    <span>
                      {trainingState.currentEpoch}/{trainingState.totalEpochs} 轮
                    </span>
                  </div>
                  <Progress value={trainingState.progress} className="h-2" />
                </div>

                {/* Latest epoch metrics */}
                {currentMetrics && (
                  <div className="grid grid-cols-2 md:grid-cols-4 gap-4">
                    <MetricTile
                      value={currentMetrics.loss.toFixed(4)}
                      label="训练损失"
                      valueClassName="text-destructive"
                    />
                    <MetricTile
                      value={`${(currentMetrics.accuracy * 100).toFixed(1)}%`}
                      label="训练准确率"
                      valueClassName="text-primary"
                    />
                    <MetricTile
                      value={currentMetrics.valLoss.toFixed(4)}
                      label="验证损失"
                      valueClassName="text-destructive"
                    />
                    <MetricTile
                      value={`${(currentMetrics.valAccuracy * 100).toFixed(1)}%`}
                      label="验证准确率"
                      valueClassName="text-primary"
                    />
                  </div>
                )}
              </div>
            </CardContent>
          </Card>

          {/* Training curves */}
          <Card>
            <CardHeader>
              <CardTitle className="font-serif">训练曲线</CardTitle>
              <CardDescription>实时监控模型训练过程中的性能指标</CardDescription>
            </CardHeader>
            <CardContent>
              <Tabs defaultValue="loss" className="w-full">
                <TabsList className="grid w-full grid-cols-2">
                  <TabsTrigger value="loss">损失函数</TabsTrigger>
                  <TabsTrigger value="accuracy">准确率</TabsTrigger>
                </TabsList>

                <TabsContent value="loss" className="space-y-4">
                  <div className="h-64 bg-muted/30 rounded-lg p-4 relative overflow-hidden">
                    <div className="absolute inset-4">
                      <svg className="w-full h-full" viewBox="0 0 400 200" preserveAspectRatio="none">
                        <ChartGrid patternId="model-training-loss-grid" />
                        {hasCurve && (
                          <>
                            {/* Training loss */}
                            <polyline
                              fill="none"
                              stroke="rgb(239 68 68)"
                              strokeWidth="2"
                              points={toPolylinePoints(
                                history.map((m) => m.loss),
                                maxLoss,
                              )}
                            />
                            {/* Validation loss */}
                            <polyline
                              fill="none"
                              stroke="rgb(249 115 22)"
                              strokeWidth="2"
                              strokeDasharray="5,5"
                              points={toPolylinePoints(
                                history.map((m) => m.valLoss),
                                maxLoss,
                              )}
                            />
                          </>
                        )}
                      </svg>
                    </div>
                    <div className="absolute bottom-2 left-4 flex gap-4 text-xs">
                      <div className="flex items-center gap-1">
                        <div className="w-3 h-0.5 bg-red-500" />
                        <span>训练损失</span>
                      </div>
                      <div className="flex items-center gap-1">
                        {/* A border (not a background) is needed for the dash to render. */}
                        <div className="w-3 border-t-2 border-dashed border-orange-500" />
                        <span>验证损失</span>
                      </div>
                    </div>
                  </div>
                </TabsContent>

                <TabsContent value="accuracy" className="space-y-4">
                  <div className="h-64 bg-muted/30 rounded-lg p-4 relative overflow-hidden">
                    <div className="absolute inset-4">
                      <svg className="w-full h-full" viewBox="0 0 400 200" preserveAspectRatio="none">
                        <ChartGrid patternId="model-training-accuracy-grid" />
                        {hasCurve && (
                          <>
                            {/* Training accuracy */}
                            <polyline
                              fill="none"
                              stroke="rgb(34 197 94)"
                              strokeWidth="2"
                              points={toPolylinePoints(
                                history.map((m) => m.accuracy),
                                maxAccuracy,
                              )}
                            />
                            {/* Validation accuracy */}
                            <polyline
                              fill="none"
                              stroke="rgb(59 130 246)"
                              strokeWidth="2"
                              strokeDasharray="5,5"
                              points={toPolylinePoints(
                                history.map((m) => m.valAccuracy),
                                maxAccuracy,
                              )}
                            />
                          </>
                        )}
                      </svg>
                    </div>
                    <div className="absolute bottom-2 left-4 flex gap-4 text-xs">
                      <div className="flex items-center gap-1">
                        <div className="w-3 h-0.5 bg-green-500" />
                        <span>训练准确率</span>
                      </div>
                      <div className="flex items-center gap-1">
                        <div className="w-3 border-t-2 border-dashed border-blue-500" />
                        <span>验证准确率</span>
                      </div>
                    </div>
                  </div>
                </TabsContent>
              </Tabs>
            </CardContent>
          </Card>
        </div>
      </div>

      {/* Next step: only available once training has completed */}
      <div className="flex justify-end">
        <Button size="lg" disabled={trainingState.phase !== "completed"} className="font-semibold" onClick={handleNext}>
          继续到图像生成
          <Zap className="w-4 h-4 ml-2" />
        </Button>
      </div>
    </div>
  )
}
|