diff --git a/apps/web/src/pages/ModelsPage.tsx b/apps/web/src/pages/ModelsPage.tsx index 7c5f501..f0bdcbd 100644 --- a/apps/web/src/pages/ModelsPage.tsx +++ b/apps/web/src/pages/ModelsPage.tsx @@ -1,5 +1,5 @@ import { useState } from 'react'; -import { Play, Rocket, Globe, ChevronDown, ChevronUp, Beaker, Shield, Scissors, Wrench, Zap, BarChart3 } from 'lucide-react'; +import { Play, Rocket, Globe, ChevronDown, ChevronUp, Beaker, Shield, Zap, BarChart3 } from 'lucide-react'; import { TutorialHint } from '@/components/game/TutorialHint'; import { ConfirmModal } from '@/components/common/ConfirmModal'; import { useGameStore } from '@/store'; @@ -9,10 +9,14 @@ import { ALIGNMENT_METHODS, QUANTIZATION_CONFIGS, PARAMETER_OPTIONS, + SIZE_TIER_MAP, + SIZE_TIER_LABELS, + SFT_SPECIALIZATION_BONUSES, } from '@ai-tycoon/shared'; import type { ModelArchitecture, DataMixAllocation, SFTSpecialization, AlignmentMethod, DataDomain, QuantizationLevel, BaseModel, ModelVariant, BenchmarkResult, + SizeTier, ModelFamily, } from '@ai-tycoon/shared'; import { BENCHMARKS } from '@ai-tycoon/game-engine'; @@ -56,12 +60,8 @@ export function ModelsPage() { const totalData = useGameStore((s) => s.data.totalTrainingTokens); const currentEra = useGameStore((s) => s.meta.currentEra); const startTrainingPipeline = useGameStore((s) => s.startTrainingPipeline); - const configureSFT = useGameStore((s) => s.configureSFT); - const configureAlignment = useGameStore((s) => s.configureAlignment); const deployModel = useGameStore((s) => s.deployModel); const deployVariant = useGameStore((s) => s.deployVariant); - const createDistillation = useGameStore((s) => s.createDistillation); - const createFineTune = useGameStore((s) => s.createFineTune); const createQuantization = useGameStore((s) => s.createQuantization); const startEvaluation = useGameStore((s) => s.startEvaluation); const setTrainingAllocation = useGameStore((s) => s.setTrainingAllocation); @@ -80,6 +80,15 @@ export function ModelsPage() { const [dataMix, setDataMix] = useState({ ...DEFAULT_DATA_MIX }); const [dataMixPreset, setDataMixPreset] = useState('balanced'); + // New model lifecycle state + const [familyMode, setFamilyMode] = useState<'new' | 'existing'>('new'); + const [selectedFamilyId, setSelectedFamilyId] = useState(null); + const [isPointRelease, setIsPointRelease] = useState(false); + const [sourceModelId, setSourceModelId] = useState(null); + const [sftSpecs, setSftSpecs] = useState(['general']); + const [alignMethod, setAlignMethod] = useState('rlhf'); + const [safetyWeight, setSafetyWeight] = useState(0.5); + const trainingFlops = totalFlops * trainingAlloc; const estimatedTicks = trainingFlops > 0 ? Math.max(30, Math.ceil(180 / (1 + trainingFlops * 0.1))) : Infinity; const estimatedCapability = Math.min(95, Math.sqrt(trainingFlops) * 5 + Math.log10(1 + totalData / 1e8) * 10); @@ -95,9 +104,26 @@ export function ModelsPage() { const currentEraIdx = eraOrder.indexOf(currentEra); const availableBenchmarks = BENCHMARKS.filter(b => eraOrder.indexOf(b.unlockedAtEra) <= currentEraIdx); + const hasAlignmentResearch = completedResearch.some(r => + r === 'alignment-research' || r === 'interpretability' || r === 'constitutional-ai', + ); + + // Computed size tier + const sizeTier: SizeTier = SIZE_TIER_MAP[parameterCount] ?? 'small'; + + // Model name preview + const familyNameForPreview = familyMode === 'new' + ? (modelName.trim() || `Family ${families.length + 1}`) + : (families.find(f => f.id === selectedFamilyId)?.name ?? 'Family'); + const nextVersion = (() => { + if (!isPointRelease || !sourceModelId) return 1.0; + const src = baseModels.find(m => m.id === sourceModelId); + return src ? Math.round((src.version + 0.1) * 10) / 10 : 1.0; + })(); + const modelNamePreview = `${familyNameForPreview} ${SIZE_TIER_LABELS[sizeTier]} v${nextVersion.toFixed(1)}`; + const handleStartTraining = () => { if (trainingFlops === 0) return; - const name = modelName.trim() || `Model v${families.length + 1}`; const architecture: ModelArchitecture = { type: archType, @@ -109,14 +135,23 @@ export function ModelsPage() { }; startTrainingPipeline({ - modelName: name, + ...(familyMode === 'new' + ? { familyName: modelName.trim() || `Family ${families.length + 1}` } + : { familyId: selectedFamilyId! }), architecture, dataMix, allocatedComputeFraction: 1.0, targetTokens: totalData, totalTicks: estimatedTicks, + sftSpecializations: sftSpecs, + alignmentMethod: alignMethod, + alignmentSafetyWeight: safetyWeight, + isPointRelease, + sourceModelId: sourceModelId ?? undefined, }); setModelName(''); + setIsPointRelease(false); + setSourceModelId(null); }; const handlePresetChange = (presetKey: string) => { @@ -137,10 +172,6 @@ export function ModelsPage() { setDataMixPreset('custom'); }; - const hasAlignmentResearch = completedResearch.some(r => - r === 'alignment-research' || r === 'interpretability' || r === 'constitutional-ai', - ); - return (

Models

@@ -243,8 +274,8 @@ export function ModelsPage() {
- - + +
@@ -259,19 +290,6 @@ export function ModelsPage() { : `ETA: ${formatDuration(stage.totalTicks - stage.progressTicks)}`}
- {pipeline.currentStage === 'pretraining' && !pipeline.stages.sft && !pipeline.stages.alignment && ( -
- - - Post-training not configured —{' '} - - {' '}or they'll be skipped. - -
- )} - {isExpanded && (
{pipeline.currentStage === 'pretraining' && ( @@ -281,10 +299,6 @@ export function ModelsPage() {
)} - {pipeline.currentStage === 'pretraining' && !pipeline.stages.pretraining.isComplete && (!pipeline.stages.sft || !pipeline.stages.alignment) && ( - - )} - {recentEvents.length > 0 && (
Recent Events @@ -357,17 +371,60 @@ export function ModelsPage() { {/* Train New Model */} {modelsTab === 'train' &&

Train New Model

+ + {isPointRelease && sourceModelId && ( +
+
+ Point Release — iterating on {baseModels.find(m => m.id === sourceModelId)?.name ?? 'model'} + (40% training time) +
+ +
+ )} +
-
-
- + + {/* Family selector */} +
+ +
+ + +
+ {familyMode === 'new' ? ( setModelName(e.target.value)} - placeholder={`Model v${families.length + 1}`} + placeholder={`Family ${families.length + 1}`} className="w-full bg-surface-800 border border-surface-600 rounded px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-accent/50" /> -
+ ) : ( + + )} +
+ + {/* Architecture & Parameters */} +
@@ -385,20 +442,6 @@ export function ModelsPage() {
-
-
-
- - -
setParameterCount(Number(e.target.value))} + className="flex-1 bg-surface-800 border border-surface-600 rounded px-3 py-2 text-sm focus:outline-none focus:ring-2 focus:ring-accent/50" + > + {PARAMETER_OPTIONS.map(p => ( + + ))} + + + {SIZE_TIER_LABELS[sizeTier]} + +
+
+ {/* Data Mix */}
@@ -448,6 +510,94 @@ export function ModelsPage() {
+ {/* SFT Configuration */} +
+
+ + +
+
+ {SFT_OPTIONS.map(opt => ( + + ))} +
+ {sftSpecs.length > 0 && ( +
+ Bonus preview: + {sftSpecs.map(spec => { + const bonuses = SFT_SPECIALIZATION_BONUSES[spec]; + if (!bonuses) return null; + const positives = Object.entries(bonuses).filter(([, v]) => v > 0).map(([k, v]) => `${k} +${v}`); + const negatives = Object.entries(bonuses).filter(([, v]) => v < 0).map(([k, v]) => `${k} ${v}`); + return ( + + {spec} + {positives.length > 0 && {positives.join(', ')}} + {negatives.length > 0 && {negatives.join(', ')}} + + ); + })} +
+ )} +
+ + {/* Alignment Configuration */} +
+
+ + +
+ {hasAlignmentResearch ? ( +
+
+ {(Object.keys(ALIGNMENT_METHODS) as AlignmentMethod[]).map(method => { + const isAvailable = completedResearch.includes(ALIGNMENT_METHODS[method].requiredResearch); + return ( + + ); + })} +
+
+ Safety + setSafetyWeight(Number(e.target.value) / 100)} + className="flex-1 accent-accent h-1" /> + Helpful + {Math.round(safetyWeight * 100)}% +
+
+ ) : ( +
+ Requires alignment research — defaults to RLHF +
+ )} +
+ {/* Stats */}
@@ -471,14 +621,22 @@ export function ModelsPage() { Estimated capability: {estimatedCapability.toFixed(1)}/100 {archType === 'moe' && (+15% MoE bonus)}
+ + {/* Model name preview */} +
+ Model name: + {modelNamePreview} +
+ + {/* Start button */}
{trainingFlops === 0 && totalFlops === 0 && (

Build a data center and order racks first

@@ -495,63 +653,77 @@ export function ModelsPage() {

Model Families

{families.map(family => { - const base = baseModels.find(m => m.familyId === family.id); + const familyModels = baseModels.filter(m => m.familyId === family.id); const variants = family.variants; const isExpanded = expandedModel === family.id; - if (!base) return null; - return (
-
+ {/* Family header */} +
- -
-

{family.name} Gen {family.generation}

-
- {base.architecture.totalParameters}B {base.architecture.type.toUpperCase()} · Cap: {base.rawCapability.toFixed(1)} · Safety: {base.safetyProfile.overallSafety.toFixed(0)} - {variants.length > 0 && · {variants.length} variant{variants.length > 1 ? 's' : ''}} -
-
+

{family.name} Gen {family.generation}

- deployModel(base.id)} onOpenSource={() => openSourceModel(base.id)} /> +
- {isExpanded && ( + {/* Model rows */} + {familyModels.map(model => ( +
+
+ {model.name} + {model.architecture.totalParameters}B + Cap: {model.rawCapability.toFixed(1)} +
+
+ {model.isDeployed ? ( + Deployed + ) : ( + + )} + +
+
+ ))} + + {familyModels.length === 0 && ( +

Training in progress...

+ )} + + {/* Expanded: details, quantize, eval for each model */} + {isExpanded && familyModels.length > 0 && (
- {/* Base model details */} - + {familyModels.map(model => ( +
+
{model.name}
+ + + +
+ ))} - {/* Variant creation */} - - - {/* Benchmark evaluation */} - - - {/* Variants tree */} {variants.length > 0 && (
- Variants + Quantized Variants {variants.map(variant => ( void; - onFineTune: (baseModelId: string, spec: SFTSpecialization, name: string) => void; onQuantize: (baseModelId: string, level: QuantizationLevel, name: string) => void; }) { const [showCreator, setShowCreator] = useState(false); - const [creatorTab, setCreatorTab] = useState<'distill' | 'finetune' | 'quantize'>('quantize'); - const [distillParams, setDistillParams] = useState(7); - const [ftSpec, setFtSpec] = useState('code'); const [quantLevel, setQuantLevel] = useState('int8'); - - const hasDistillation = completedResearch.includes('distillation'); const hasQuantization = completedResearch.includes('quantization') || completedResearch.includes('model-compression'); - const smallerParams = PARAMETER_OPTIONS.filter(p => p < model.architecture.totalParameters); + if (!hasQuantization) return null; if (!showCreator) { return ( @@ -778,7 +943,7 @@ function VariantCreator({ model, completedResearch, onDistill, onFineTune, onQua onClick={() => setShowCreator(true)} className="flex items-center gap-1 text-xs text-accent hover:text-accent-light" > - Create Variant + Create Quantized Variant ); } @@ -786,91 +951,31 @@ function VariantCreator({ model, completedResearch, onDistill, onFineTune, onQua return (
- Create Variant + Quantize {model.name}
-
- {hasDistillation && ( - - )} - - {hasQuantization && ( - - )} -
- - {creatorTab === 'distill' && hasDistillation && ( -
-
- - -
-
- Retention: ~{((0.70 + (distillParams / model.architecture.totalParameters) * 0.25) * 100).toFixed(0)}% quality -
- -
- )} - - {creatorTab === 'finetune' && ( -
-
- -
- {SFT_OPTIONS.map(opt => ( - - ))} -
+ ); + })}
-
- )} - - {creatorTab === 'quantize' && hasQuantization && ( -
-
- -
- {(Object.keys(QUANTIZATION_CONFIGS) as QuantizationLevel[]).map(level => { - const cfg = QUANTIZATION_CONFIGS[level]; - return ( - - ); - })} -
-
- -
- )} + +
); } @@ -961,11 +1066,6 @@ function VariantCard({ variant, familyId, benchmarkResults, availableBenchmarks, const [isExpanded, setIsExpanded] = useState(false); const variantResults = benchmarkResults.filter(r => r.modelId === variant.id); - const typeLabel = variant.variantType === 'distilled' ? 'Distilled' - : variant.variantType === 'fine-tuned' ? 'Fine-tuned' : 'Quantized'; - const typeColor = variant.variantType === 'distilled' ? 'text-purple-400' - : variant.variantType === 'fine-tuned' ? 'text-yellow-400' : 'text-green-400'; - return (
@@ -975,9 +1075,8 @@ function VariantCard({ variant, familyId, benchmarkResults, availableBenchmarks,
{variant.name} - {typeLabel} + Quantized {variant.quantization && {variant.quantization.toUpperCase()}} - {variant.finetuneSpecialization && {variant.finetuneSpecialization}}
@@ -1108,16 +1207,15 @@ function BenchmarkLeaderboard({ benchmarkResults, baseModels, families, availabl ); } -function StageBar({ label, active, complete, progress, configured = true }: { - label: string; active: boolean; complete: boolean; progress: number; configured?: boolean; +function StageBar({ label, active, complete, progress }: { + label: string; active: boolean; complete: boolean; progress: number; }) { return (
-
- {label}{!configured && ' (skip)'} +
+ {label}
); } - -function PostTrainingConfig({ pipelineId, hasAlignmentResearch, completedResearch, configureSFT, configureAlignment, sftConfigured, alignmentConfigured }: { - pipelineId: string; - hasAlignmentResearch: boolean; - completedResearch: string[]; - configureSFT: (pipelineId: string, specializations: SFTSpecialization[]) => void; - configureAlignment: (pipelineId: string, method: AlignmentMethod, safetyWeight: number) => void; - sftConfigured: boolean; - alignmentConfigured: boolean; -}) { - const [selectedSpecs, setSelectedSpecs] = useState(['general']); - const [alignMethod, setAlignMethod] = useState('rlhf'); - const [safetyWeight, setSafetyWeight] = useState(0.5); - - return ( -
-
Configure Post-Training (optional)
- - {!sftConfigured ? ( -
-
- Supervised Fine-Tuning -
-
- {SFT_OPTIONS.map(opt => ( - - ))} -
- -
- ) : ( -
- SFT configured -
- )} - - {!alignmentConfigured ? ( - hasAlignmentResearch ? ( -
-
- Alignment -
-
- {(Object.keys(ALIGNMENT_METHODS) as AlignmentMethod[]).map(method => { - const isAvailable = completedResearch.includes(ALIGNMENT_METHODS[method].requiredResearch); - return ( - - ); - })} -
-
- Safety - setSafetyWeight(Number(e.target.value) / 100)} - className="flex-1 accent-accent h-1" /> - Helpful -
- -
- ) : ( -
- Alignment requires research -
- ) - ) : ( -
- Alignment configured -
- )} -
- ); -} diff --git a/apps/web/src/store/index.ts b/apps/web/src/store/index.ts index 5fb7690..9f925db 100644 --- a/apps/web/src/store/index.ts +++ b/apps/web/src/store/index.ts @@ -13,7 +13,7 @@ import type { CoolingType, NetworkFabric, FundingRoundType, OverloadPolicy, TrainingPipeline, ModelFamily, DataMixAllocation, - ModelArchitecture, + ModelArchitecture, AlignmentMethod, SizeTier, SFTSpecialization, QuantizationLevel, VariantCreationJob, EvalJob, ConsumerTierId, ApiTierId, @@ -36,9 +36,10 @@ import { COOLING_TYPE_CONFIGS, COOLING_ORDER, NETWORK_FABRIC_CONFIGS, FABRIC_ORDER, DEFAULT_DATA_MIX, MAX_CONCURRENT_TRAINING, - DISTILLATION_TIME_FRACTION, DISTILLATION_COMPUTE_FRACTION, - FINETUNE_TIME_FRACTION, FINETUNE_COMPUTE_FRACTION, QUANTIZATION_TICKS, + SFT_TIME_FRACTION, ALIGNMENT_TIME_FRACTION, + SIZE_TIER_MAP, SIZE_TIER_LABELS, + POINT_RELEASE_TIME_FRACTION, POINT_RELEASE_MAX_VERSION, } from '@ai-tycoon/shared'; import { emptyDCNetworkSummary, emptyCampusNetworkSummary, emptyClusterNetworkSummary, @@ -115,11 +116,21 @@ interface Actions { upgradeDataCenter: (dataCenterId: string, upgrade: 'cooling' | 'redundancy') => void; upgradeCoolingType: (dataCenterId: string, targetCooling: CoolingType) => void; upgradeNetworkFabric: (dataCenterId: string, targetFabric: NetworkFabric) => void; - startTrainingPipeline: (config: { modelName: string; architecture: ModelArchitecture; dataMix: DataMixAllocation; allocatedComputeFraction: number; targetTokens: number; totalTicks: number }) => void; - configureSFT: (pipelineId: string, specializations: import('@ai-tycoon/shared').SFTSpecialization[]) => void; - configureAlignment: (pipelineId: string, method: import('@ai-tycoon/shared').AlignmentMethod, safetyWeight: number) => void; - createDistillation: (baseModelId: string, targetParameters: number, variantName: string) => void; - createFineTune: (baseModelId: string, specialization: SFTSpecialization, variantName: string) => void; + startTrainingPipeline: (config: { + familyId?: string; + familyName?: string; + architecture: ModelArchitecture; + dataMix: DataMixAllocation; + allocatedComputeFraction: number; + targetTokens: number; + totalTicks: number; + sftSpecializations: SFTSpecialization[]; + alignmentMethod: AlignmentMethod; + alignmentSafetyWeight: number; + isPointRelease?: boolean; + sourceModelId?: string; + }) => void; + startPointRelease: (baseModelId: string) => void; createQuantization: (baseModelId: string, level: QuantizationLevel, variantName: string) => void; startEvaluation: (modelId: string, benchmarkIds: string[]) => void; deployModel: (modelId: string) => void; @@ -917,29 +928,52 @@ export const useGameStore = create()( startTrainingPipeline: (config) => { let created = false; + let toastName = ''; set((s) => { const activeCount = s.models.activeTrainingPipelines.filter(p => p.status === 'active' || p.status === 'stalled').length; const maxSlots = MAX_CONCURRENT_TRAINING[s.meta.currentEra] ?? 1; if (activeCount >= maxSlots) return s; created = true; - const familyId = uuid(); - const pipelineId = uuid(); - const generation = s.models.families.length + 1; - const family: ModelFamily = { - id: familyId, - name: config.modelName, - generation, - baseModelId: null, - variants: [], - createdAtTick: s.meta.tickCount, - }; + let familyId: string; + let updatedFamilies = [...s.models.families]; + + if (config.familyId) { + familyId = config.familyId; + } else { + familyId = uuid(); + const generation = s.models.families.length + 1; + const family: ModelFamily = { + id: familyId, + name: config.familyName ?? 'Model', + generation, + baseModelIds: [], + variants: [], + createdAtTick: s.meta.tickCount, + }; + updatedFamilies = [...updatedFamilies, family]; + } + + const sizeTier: SizeTier = SIZE_TIER_MAP[config.architecture.totalParameters] ?? 'small'; + const familyName = config.familyName ?? updatedFamilies.find(f => f.id === familyId)?.name ?? 'Model'; + const version = config.isPointRelease && config.sourceModelId + ? (() => { + const src = s.models.baseModels.find(m => m.id === config.sourceModelId); + return src ? Math.round((src.version + 0.1) * 10) / 10 : 1.0; + })() + : 1.0; + const modelName = `${familyName} ${SIZE_TIER_LABELS[sizeTier]} v${version.toFixed(1)}`; + toastName = modelName; + + const baseTotalTicks = config.isPointRelease + ? Math.ceil(config.totalTicks * POINT_RELEASE_TIME_FRACTION) + : config.totalTicks; const pipeline: TrainingPipeline = { - id: pipelineId, + id: uuid(), familyId, - modelName: config.modelName, + modelName, architecture: config.architecture, dataMix: config.dataMix, currentStage: 'pretraining', @@ -949,130 +983,70 @@ export const useGameStore = create()( processedTokens: 0, computeAllocated: 0, progressTicks: 0, - totalTicks: config.totalTicks, + totalTicks: baseTotalTicks, lossValue: 10, chinchillaRatio: config.targetTokens / (config.architecture.totalParameters * 1e9), isComplete: false, }, - sft: null, - alignment: null, + sft: { + specializations: config.sftSpecializations, + progressTicks: 0, + totalTicks: Math.ceil(baseTotalTicks * SFT_TIME_FRACTION), + isComplete: false, + }, + alignment: { + method: config.alignmentMethod, + safetyWeight: config.alignmentSafetyWeight, + helpfulnessWeight: 1 - config.alignmentSafetyWeight, + progressTicks: 0, + totalTicks: Math.ceil(baseTotalTicks * ALIGNMENT_TIME_FRACTION), + isComplete: false, + }, }, status: 'active', allocatedComputeFraction: config.allocatedComputeFraction, events: [], startedAtTick: s.meta.tickCount, + sizeTier, + isPointRelease: config.isPointRelease ?? false, + sourceModelId: config.sourceModelId ?? null, }; return { models: { ...s.models, - families: [...s.models.families, family], + families: updatedFamilies, activeTrainingPipelines: [...s.models.activeTrainingPipelines, pipeline], }, }; }); if (created) { - get().addNotification({ title: 'Training Started', message: `${config.modelName} pre-training has begun.`, type: 'info', tick: get().meta.tickCount }); + get().addNotification({ title: 'Training Started', message: `${toastName} training has begun.`, type: 'info', tick: get().meta.tickCount }); set({ modelsTab: 'overview' as ModelsTab }); } }, - configureSFT: (pipelineId, specializations) => { - set((s) => ({ - models: { - ...s.models, - activeTrainingPipelines: s.models.activeTrainingPipelines.map(p => - p.id === pipelineId ? { - ...p, - stages: { - ...p.stages, - sft: { - specializations, - progressTicks: 0, - totalTicks: Math.ceil(p.stages.pretraining.totalTicks * 0.10), - isComplete: false, - }, - }, - } : p, - ), - }, - })); - get().addNotification({ title: 'SFT Configured', message: `${specializations.join(', ')} specializations enabled.`, type: 'success', tick: get().meta.tickCount }); - }, + startPointRelease: (baseModelId) => { + const s = get(); + const base = s.models.baseModels.find(m => m.id === baseModelId); + if (!base) return; + if (base.version >= POINT_RELEASE_MAX_VERSION) return; + const family = s.models.families.find(f => f.id === base.familyId); + if (!family) return; - configureAlignment: (pipelineId, method, safetyWeight) => { - set((s) => ({ - models: { - ...s.models, - activeTrainingPipelines: s.models.activeTrainingPipelines.map(p => - p.id === pipelineId ? { - ...p, - stages: { - ...p.stages, - alignment: { - method, - safetyWeight, - helpfulnessWeight: 1 - safetyWeight, - progressTicks: 0, - totalTicks: Math.ceil(p.stages.pretraining.totalTicks * 0.08), - isComplete: false, - }, - }, - } : p, - ), - }, - })); - get().addNotification({ title: 'Alignment Configured', message: `${method.toUpperCase()} alignment enabled.`, type: 'success', tick: get().meta.tickCount }); - }, - - createDistillation: (baseModelId, targetParameters, variantName) => { - let created = false; - set((s) => { - const base = s.models.baseModels.find(m => m.id === baseModelId); - if (!base) return s; - created = true; - const job: VariantCreationJob = { - id: uuid(), - familyId: base.familyId, - baseModelId, - jobType: 'distillation', - config: { targetParameters, targetArchitecture: base.architecture.type, variantName }, - progressTicks: 0, - totalTicks: Math.ceil(base.trainingCostTotal > 0 ? DISTILLATION_TIME_FRACTION * 120 : 30), - allocatedComputeFraction: DISTILLATION_COMPUTE_FRACTION, - status: 'active', - }; - return { models: { ...s.models, variantJobs: [...s.models.variantJobs, job] } }; + get().startTrainingPipeline({ + familyId: base.familyId, + architecture: base.architecture, + dataMix: base.dataMix, + allocatedComputeFraction: 1.0, + targetTokens: base.architecture.totalParameters * 20e9, + totalTicks: Math.ceil(base.architecture.totalParameters * 2 + 60), + sftSpecializations: base.sftSpecializations, + alignmentMethod: base.alignmentMethod ?? 'rlhf', + alignmentSafetyWeight: 0.5, + isPointRelease: true, + sourceModelId: baseModelId, }); - if (created) { - get().addNotification({ title: 'Distillation Started', message: `${variantName} distillation in progress.`, type: 'info', tick: get().meta.tickCount }); - set({ modelsTab: 'overview' as ModelsTab }); - } - }, - - createFineTune: (baseModelId, specialization, variantName) => { - let created = false; - set((s) => { - const base = s.models.baseModels.find(m => m.id === baseModelId); - if (!base) return s; - created = true; - const job: VariantCreationJob = { - id: uuid(), - familyId: base.familyId, - baseModelId, - jobType: 'fine-tuning', - config: { specialization, datasetIds: [], variantName }, - progressTicks: 0, - totalTicks: Math.ceil(FINETUNE_TIME_FRACTION * 120), - allocatedComputeFraction: FINETUNE_COMPUTE_FRACTION, - status: 'active', - }; - return { models: { ...s.models, variantJobs: [...s.models.variantJobs, job] } }; - }); - if (created) { - get().addNotification({ title: 'Fine-Tuning Started', message: `${variantName} fine-tuning in progress.`, type: 'info', tick: get().meta.tickCount }); - set({ modelsTab: 'overview' as ModelsTab }); - } }, createQuantization: (baseModelId, level, variantName) => { diff --git a/packages/game-engine/src/systems/modelSystem.ts b/packages/game-engine/src/systems/modelSystem.ts index 437cf6e..cbd90a1 100644 --- a/packages/game-engine/src/systems/modelSystem.ts +++ b/packages/game-engine/src/systems/modelSystem.ts @@ -7,8 +7,6 @@ import type { import { BENCHMARKS } from '../data/benchmarks'; import { uuid, VRAM_REQUIREMENTS_BY_GENERATION, - SFT_TIME_FRACTION, SFT_COMPUTE_FRACTION, - ALIGNMENT_TIME_FRACTION, ALIGNMENT_COMPUTE_FRACTION, MOE_CAPABILITY_MULTIPLIER, MOE_SPEED_MULTIPLIER, EVENT_BASE_PROBABILITY, LOSS_SPIKE_DELAY_MIN, LOSS_SPIKE_DELAY_MAX, @@ -18,8 +16,8 @@ import { ALIGNMENT_METHODS, SFT_SPECIALIZATION_BONUSES, QUANTIZATION_CONFIGS, - DISTILLATION_BASE_RETENTION, - QUANTIZATION_TICKS, + POINT_RELEASE_CAPABILITY_GAIN, + SIZE_TIER_LABELS, } from '@ai-tycoon/shared'; import type { ResearchBonuses } from './researchBonuses'; @@ -101,60 +99,25 @@ export function processModels(state: GameState, researchBonuses?: ResearchBonuse stage.computeAllocated = effectiveFlops; stage.lossValue = Math.max(0.01, 10 * Math.exp(-stage.progressTicks / stage.totalTicks * 3)); - const progressRatio = stage.progressTicks / stage.totalTicks; - if (progressRatio >= 0.75 && progressRatio < 0.78 && !pipeline.stages.sft && !pipeline.stages.alignment) { - notifications.push({ - title: 'Post-Training Reminder', - message: `${pipeline.modelName} is 75% pre-trained. Configure SFT/Alignment now or they'll be skipped!`, - type: 'warning', - action: { label: 'Configure Now', page: 'models', modelsTab: 'overview' }, - }); - } - if (stage.progressTicks >= stage.totalTicks) { stage.isComplete = true; stage.progressTicks = stage.totalTicks; - - if (updated.stages.sft) { - updated.currentStage = 'sft'; - notifications.push({ title: 'Pre-training Complete', message: `${pipeline.modelName}: Moving to supervised fine-tuning.`, type: 'info' }); - } else if (updated.stages.alignment) { - updated.currentStage = 'alignment'; - notifications.push({ title: 'Pre-training Complete', message: `${pipeline.modelName}: Moving to alignment.`, type: 'info' }); - } else { - const model = createBaseModel(updated, state, researchBonuses); - baseModels = [...baseModels, model]; - families = families.map(f => - f.id === pipeline.familyId ? { ...f, baseModelId: model.id } : f, - ); - completedModels.push(model); - updated.status = 'completed'; - } + updated.currentStage = 'sft'; + notifications.push({ title: 'Pre-training Complete', message: `${pipeline.modelName}: Moving to supervised fine-tuning.`, type: 'info' }); } updated = { ...updated, stages: { ...updated.stages, pretraining: stage } }; - } else if (pipeline.currentStage === 'sft' && pipeline.stages.sft) { + } else if (pipeline.currentStage === 'sft') { const stage = { ...pipeline.stages.sft }; stage.progressTicks += speedMultiplier; if (stage.progressTicks >= stage.totalTicks) { stage.isComplete = true; stage.progressTicks = stage.totalTicks; - - if (updated.stages.alignment) { - updated.currentStage = 'alignment'; - notifications.push({ title: 'SFT Complete', message: `${pipeline.modelName}: Moving to alignment.`, type: 'info' }); - } else { - const model = createBaseModel(updated, state, researchBonuses); - baseModels = [...baseModels, model]; - families = families.map(f => - f.id === pipeline.familyId ? { ...f, baseModelId: model.id } : f, - ); - completedModels.push(model); - updated.status = 'completed'; - } + updated.currentStage = 'alignment'; + notifications.push({ title: 'SFT Complete', message: `${pipeline.modelName}: Moving to alignment.`, type: 'info' }); } updated = { ...updated, stages: { ...updated.stages, sft: stage } }; - } else if (pipeline.currentStage === 'alignment' && pipeline.stages.alignment) { + } else if (pipeline.currentStage === 'alignment') { const stage = { ...pipeline.stages.alignment }; stage.progressTicks += speedMultiplier; @@ -165,7 +128,7 @@ export function processModels(state: GameState, researchBonuses?: ResearchBonuse const model = createBaseModel(updated, state, researchBonuses); baseModels = [...baseModels, model]; families = families.map(f => - f.id === pipeline.familyId ? { ...f, baseModelId: model.id } : f, + f.id === pipeline.familyId ? { ...f, baseModelIds: [...f.baseModelIds, model.id] } : f, ); completedModels.push(model); updated.status = 'completed'; @@ -320,79 +283,93 @@ function createBaseModel( const dataTokens = pipeline.stages.pretraining.targetTokens; const params = architecture.totalParameters; - // Pillar 1: Parameters (0-30) — larger models have higher ceiling - const paramFactor = Math.min(30, Math.log2(1 + params) * 4.5); + const sourceModel = pipeline.isPointRelease && pipeline.sourceModelId + ? state.models.baseModels.find(m => m.id === pipeline.sourceModelId) + : null; - // Pillar 2: Compute (0-25) — compute relative to parameter count (Chinchilla scaling) - const computePerParam = compute / Math.max(1, params); - const computeFactor = Math.min(25, Math.sqrt(computePerParam) * 8); + let rawCapability: number; + let capabilities: ModelCapabilities; - // Pillar 3: Data (0-20) — token count with quality multiplier - const dataQualityMultiplier = 1 + (researchBonuses?.dataQualityBonus ?? 0); - const dataFactor = Math.min(20, Math.log10(1 + dataTokens / 1e8) * 8 * dataQualityMultiplier); + if (sourceModel) { + rawCapability = Math.min(98, sourceModel.rawCapability * (1 + POINT_RELEASE_CAPABILITY_GAIN)); - // Pillar 4: Research (0-20) — accumulated research knowledge - const capabilityResearchBonus = researchBonuses?.globalCapabilityBonus ?? 0; - const researchFactor = Math.min(20, capabilityResearchBonus + state.research.completedResearch.length * 0.5); - - let rawCapability = Math.min(95, paramFactor + computeFactor + dataFactor + researchFactor); - - if (architecture.type === 'moe') { - rawCapability = Math.min(98, rawCapability * MOE_CAPABILITY_MULTIPLIER); - } - - // MoE tradeoff: total params need full VRAM even though only active params run - // This is enforced in the UI/store when checking VRAM requirements - - const researcherQuality = state.talent.departments.research.effectiveness; - const contextBonus = Math.log2(Math.max(1, architecture.contextWindow / 4)) * 3; - const contextPenalty = Math.max(0, Math.log2(architecture.contextWindow / 8)) * 2; - - const capabilities: ModelCapabilities = { - reasoning: clamp(rawCapability * (0.6 + dataMix.scientific * 0.5 + dataMix.code * 0.3) * (1 + researcherQuality * 0.2)), - coding: clamp(rawCapability * (0.5 + dataMix.code * 1.0)), - creative: clamp(rawCapability * (0.4 + dataMix.books * 0.6 + dataMix.conversation * 0.3)), - math: clamp(rawCapability * (0.3 + dataMix.scientific * 0.7 + dataMix.code * 0.2)), - knowledge: clamp(rawCapability * (0.5 + dataMix.web * 0.3 + dataMix.books * 0.3) + contextBonus * 0.3), - multimodal: clamp(rawCapability * (dataMix.images * 0.5 + dataMix.video * 0.4 + dataMix.audio * 0.2)), - agents: clamp(rawCapability * (0.2 + dataMix.code * 0.3 + dataMix.conversation * 0.2) + contextBonus * 0.5), - speed: Math.max(1, 100 - params * 0.3 - contextPenalty + (researchBonuses?.inferenceEfficiencyBonus ?? 0) * 20 + (architecture.type === 'moe' ? MOE_SPEED_MULTIPLIER * 10 : 0)), - contextUtilization: Math.min(100, architecture.contextWindow * 0.4), - }; - - if (researchBonuses) { - capabilities.reasoning = clamp(capabilities.reasoning + researchBonuses.reasoningBonus); - capabilities.coding = clamp(capabilities.coding + researchBonuses.codingBonus); - capabilities.creative = clamp(capabilities.creative + researchBonuses.creativeBonus); - capabilities.multimodal = clamp(capabilities.multimodal + researchBonuses.multimodalBonus); - capabilities.agents = clamp(capabilities.agents + researchBonuses.agentsBonus); - } - - const breakthroughBonuses: Partial> = {}; - for (const event of pipeline.events) { - if ((event.type === 'breakthrough' || event.type === 'emergent_capability') && event.impact.capabilityDomain && event.impact.capabilityBonus) { - const domain = event.impact.capabilityDomain; - breakthroughBonuses[domain] = (breakthroughBonuses[domain] ?? 0) + event.impact.capabilityBonus; + capabilities = { ...sourceModel.capabilities }; + const boost = POINT_RELEASE_CAPABILITY_GAIN; + for (const key of Object.keys(capabilities) as (keyof ModelCapabilities)[]) { + if (key !== 'speed' && key !== 'contextUtilization') { + capabilities[key] = clamp(capabilities[key] * (1 + boost)); + } + } + } else { + const paramFactor = Math.min(30, Math.log2(1 + params) * 4.5); + + const computePerParam = compute / Math.max(1, params); + const computeFactor = Math.min(25, Math.sqrt(computePerParam) * 8); + + const dataQualityMultiplier = 1 + (researchBonuses?.dataQualityBonus ?? 0); + const dataFactor = Math.min(20, Math.log10(1 + dataTokens / 1e8) * 8 * dataQualityMultiplier); + + const capabilityResearchBonus = researchBonuses?.globalCapabilityBonus ?? 0; + const researchFactor = Math.min(20, capabilityResearchBonus + state.research.completedResearch.length * 0.5); + + rawCapability = Math.min(95, paramFactor + computeFactor + dataFactor + researchFactor); + + if (architecture.type === 'moe') { + rawCapability = Math.min(98, rawCapability * MOE_CAPABILITY_MULTIPLIER); + } + + const researcherQuality = state.talent.departments.research.effectiveness; + const contextBonus = Math.log2(Math.max(1, architecture.contextWindow / 4)) * 3; + const contextPenalty = Math.max(0, Math.log2(architecture.contextWindow / 8)) * 2; + + capabilities = { + reasoning: clamp(rawCapability * (0.6 + dataMix.scientific * 0.5 + dataMix.code * 0.3) * (1 + researcherQuality * 0.2)), + coding: clamp(rawCapability * (0.5 + dataMix.code * 1.0)), + creative: clamp(rawCapability * (0.4 + dataMix.books * 0.6 + dataMix.conversation * 0.3)), + math: clamp(rawCapability * (0.3 + dataMix.scientific * 0.7 + dataMix.code * 0.2)), + knowledge: clamp(rawCapability * (0.5 + dataMix.web * 0.3 + dataMix.books * 0.3) + contextBonus * 0.3), + multimodal: clamp(rawCapability * (dataMix.images * 0.5 + dataMix.video * 0.4 + dataMix.audio * 0.2)), + agents: clamp(rawCapability * (0.2 + dataMix.code * 0.3 + dataMix.conversation * 0.2) + contextBonus * 0.5), + speed: Math.max(1, 100 - params * 0.3 - contextPenalty + (researchBonuses?.inferenceEfficiencyBonus ?? 0) * 20 + (architecture.type === 'moe' ? MOE_SPEED_MULTIPLIER * 10 : 0)), + contextUtilization: Math.min(100, architecture.contextWindow * 0.4), + }; + + if (researchBonuses) { + capabilities.reasoning = clamp(capabilities.reasoning + researchBonuses.reasoningBonus); + capabilities.coding = clamp(capabilities.coding + researchBonuses.codingBonus); + capabilities.creative = clamp(capabilities.creative + researchBonuses.creativeBonus); + capabilities.multimodal = clamp(capabilities.multimodal + researchBonuses.multimodalBonus); + capabilities.agents = clamp(capabilities.agents + researchBonuses.agentsBonus); + } + + const breakthroughBonuses: Partial> = {}; + for (const event of pipeline.events) { + if ((event.type === 'breakthrough' || event.type === 'emergent_capability') && event.impact.capabilityDomain && event.impact.capabilityBonus) { + const domain = event.impact.capabilityDomain; + breakthroughBonuses[domain] = (breakthroughBonuses[domain] ?? 0) + event.impact.capabilityBonus; + } + } + for (const [domain, bonus] of Object.entries(breakthroughBonuses)) { + const key = domain as keyof ModelCapabilities; + capabilities[key] = clamp(capabilities[key] + bonus); } - } - for (const [domain, bonus] of Object.entries(breakthroughBonuses)) { - const key = domain as keyof ModelCapabilities; - capabilities[key] = clamp(capabilities[key] + bonus); } const completedStages: ('pretraining' | 'sft' | 'alignment')[] = ['pretraining']; - if (pipeline.stages.sft?.isComplete) { + if (pipeline.stages.sft.isComplete) { completedStages.push('sft'); const sft = pipeline.stages.sft; - for (let i = 0; i < sft.specializations.length; i++) { - const spec = sft.specializations[i]; - const bonuses = SFT_SPECIALIZATION_BONUSES[spec]; - if (!bonuses) continue; - const diminishing = i === 0 ? 1.0 : i === 1 ? 0.7 : 0.4; - for (const [cap, value] of Object.entries(bonuses)) { - const key = cap as keyof ModelCapabilities; - capabilities[key] = clamp(capabilities[key] + value * diminishing); + if (!sourceModel) { + for (let i = 0; i < sft.specializations.length; i++) { + const spec = sft.specializations[i]; + const bonuses = SFT_SPECIALIZATION_BONUSES[spec]; + if (!bonuses) continue; + const diminishing = i === 0 ? 1.0 : i === 1 ? 0.7 : 0.4; + for (const [cap, value] of Object.entries(bonuses)) { + const key = cap as keyof ModelCapabilities; + capabilities[key] = clamp(capabilities[key] + value * diminishing); + } } } } @@ -401,7 +378,7 @@ function createBaseModel( let overallSafety = Math.min(100, 30 + safetyResearchBonus + Math.random() * 10); let refusalRate = overallSafety > 60 ? 0.1 : 0.03; - if (pipeline.stages.alignment?.isComplete) { + if (pipeline.stages.alignment.isComplete) { completedStages.push('alignment'); const alignment = pipeline.stages.alignment; const methodConfig = ALIGNMENT_METHODS[alignment.method]; @@ -409,10 +386,12 @@ function createBaseModel( const safetyGain = methodConfig.safetyGain * alignment.safetyWeight; overallSafety = Math.min(100, overallSafety + safetyGain); refusalRate = methodConfig.baseRefusal * Math.pow(alignment.safetyWeight, 1.5); - const capLoss = methodConfig.capabilityLoss * alignment.safetyWeight * 0.5; - for (const key of Object.keys(capabilities) as (keyof ModelCapabilities)[]) { - if (key !== 'speed' && key !== 'contextUtilization') { - capabilities[key] = clamp(capabilities[key] - capLoss); + if (!sourceModel) { + const capLoss = methodConfig.capabilityLoss * alignment.safetyWeight * 0.5; + for (const key of Object.keys(capabilities) as (keyof ModelCapabilities)[]) { + if (key !== 'speed' && key !== 'contextUtilization') { + capabilities[key] = clamp(capabilities[key] - capLoss); + } } } } @@ -426,10 +405,15 @@ function createBaseModel( honesty: overallSafety * 0.9, }; + const family = state.models.families.find(f => f.id === pipeline.familyId); + const version = sourceModel ? Math.round((sourceModel.version + 0.1) * 10) / 10 : 1.0; + const familyName = family?.name ?? pipeline.modelName; + const autoName = `${familyName} ${SIZE_TIER_LABELS[pipeline.sizeTier]} v${version.toFixed(1)}`; + return { id: uuid(), familyId: pipeline.familyId, - name: pipeline.modelName, + name: autoName, architecture, dataMix, capabilities, @@ -439,6 +423,10 @@ function createBaseModel( trainedAtTick: state.meta.tickCount, trainingCostTotal: compute, trainingStagesCompleted: completedStages, + sizeTier: pipeline.sizeTier, + version, + sftSpecializations: pipeline.stages.sft.specializations, + alignmentMethod: pipeline.stages.alignment.method, }; } @@ -467,59 +455,30 @@ function createVariant(job: VariantCreationJob, base: BaseModel): ModelVariant { const caps = { ...base.capabilities }; let costMultiplier = 1.0; let speedMultiplier = 1.0; - let variantName = base.name; - let arch = { ...base.architecture }; - if (job.jobType === 'distillation' && 'targetParameters' in job.config) { - const config = job.config; - const sizeRatio = config.targetParameters / base.architecture.totalParameters; - const retention = DISTILLATION_BASE_RETENTION + sizeRatio * 0.25; + const config = job.config; + const qConfig = QUANTIZATION_CONFIGS[config.level]; + if (qConfig) { for (const key of Object.keys(caps) as (keyof ModelCapabilities)[]) { - caps[key] = clamp(caps[key] * retention); + if (key !== 'speed') caps[key] = clamp(caps[key] * qConfig.qualityRetention); } - costMultiplier = sizeRatio * 0.8; - speedMultiplier = (1 / sizeRatio) * 0.7; - arch = { ...arch, totalParameters: config.targetParameters, activeParameters: config.targetParameters }; - variantName = config.variantName; - } else if (job.jobType === 'fine-tuning' && 'specialization' in job.config) { - const config = job.config; - const bonuses = SFT_SPECIALIZATION_BONUSES[config.specialization]; - if (bonuses) { - for (const [cap, value] of Object.entries(bonuses)) { - caps[cap as keyof ModelCapabilities] = clamp(caps[cap as keyof ModelCapabilities] + value); - } - } - variantName = config.variantName; - } else if (job.jobType === 'quantization' && 'level' in job.config) { - const config = job.config; - const qConfig = QUANTIZATION_CONFIGS[config.level]; - if (qConfig) { - for (const key of Object.keys(caps) as (keyof ModelCapabilities)[]) { - if (key !== 'speed') caps[key] = clamp(caps[key] * qConfig.qualityRetention); - } - caps.speed = clamp(caps.speed * qConfig.speedMultiplier); - costMultiplier = qConfig.costMultiplier; - speedMultiplier = qConfig.speedMultiplier; - } - variantName = config.variantName; + caps.speed = clamp(caps.speed * qConfig.speedMultiplier); + costMultiplier = qConfig.costMultiplier; + speedMultiplier = qConfig.speedMultiplier; } return { id: uuid(), familyId: base.familyId, baseModelId: base.id, - name: variantName, - variantType: job.jobType === 'distillation' ? 'distilled' : job.jobType === 'fine-tuning' ? 'fine-tuned' : 'quantized', - architecture: arch, + name: config.variantName, + variantType: 'quantized', + architecture: { ...base.architecture }, capabilities: caps, safetyProfile: { ...base.safetyProfile }, isDeployed: false, createdAtTick: 0, - quantization: job.jobType === 'quantization' && 'level' in job.config ? job.config.level : undefined, - distillationRetention: job.jobType === 'distillation' && 'targetParameters' in job.config - ? DISTILLATION_BASE_RETENTION + (job.config.targetParameters / base.architecture.totalParameters) * 0.25 - : undefined, - finetuneSpecialization: job.jobType === 'fine-tuning' && 'specialization' in job.config ? job.config.specialization : undefined, + quantization: config.level, costMultiplier, speedMultiplier, }; diff --git a/packages/shared/src/constants/gameBalance.ts b/packages/shared/src/constants/gameBalance.ts index 25dac65..6c7b643 100644 --- a/packages/shared/src/constants/gameBalance.ts +++ b/packages/shared/src/constants/gameBalance.ts @@ -1,6 +1,7 @@ import type { DCTier, DCTierConfig, RackSkuId, RackSkuConfig, SwitchTier, SwitchTierConfig, CampusTierCost, ClusterCostConfig, CoolingType, CoolingTypeConfig, NetworkFabric, NetworkFabricConfig } from '../types/infrastructure'; import type { Era } from '../types/gameState'; import type { ConsumerTierId, ApiTierId, SeasonalPhase, EnterprisePipelineStage, EnterpriseSegment, TAMSegmentId } from '../types/market'; +import type { SizeTier } from '../types/models'; export const TICK_INTERVAL_MS = 1000; export const MAX_OFFLINE_TICKS = 86_400; @@ -34,13 +35,24 @@ export const MAX_CONCURRENT_TRAINING: Record = { startup: 1, scaleup: 2, bigtech: 4, agi: 8, }; -export const DISTILLATION_COMPUTE_FRACTION = 0.15; -export const DISTILLATION_TIME_FRACTION = 0.20; -export const DISTILLATION_BASE_RETENTION = 0.70; -export const FINETUNE_COMPUTE_FRACTION = 0.03; -export const FINETUNE_TIME_FRACTION = 0.08; export const QUANTIZATION_TICKS = 8; +export const SIZE_TIER_MAP: Record = { + 1: 'nano', 3: 'nano', + 7: 'small', 13: 'small', + 30: 'medium', 70: 'medium', + 130: 'large', 300: 'large', + 700: 'flagship', 1400: 'flagship', +}; + +export const SIZE_TIER_LABELS: Record = { + nano: 'Nano', small: 'Small', medium: 'Medium', large: 'Large', flagship: 'Flagship', +}; + +export const POINT_RELEASE_TIME_FRACTION = 0.40; +export const POINT_RELEASE_CAPABILITY_GAIN = 0.08; +export const POINT_RELEASE_MAX_VERSION = 9; + export const MOE_CAPABILITY_MULTIPLIER = 1.15; export const MOE_SPEED_MULTIPLIER = 1.3; export const PARAMETER_OPTIONS = [1, 3, 7, 13, 30, 70, 130, 300, 700, 1400]; diff --git a/packages/shared/src/types/gameState.ts b/packages/shared/src/types/gameState.ts index 6c3451b..bbbc242 100644 --- a/packages/shared/src/types/gameState.ts +++ b/packages/shared/src/types/gameState.ts @@ -52,4 +52,4 @@ export const INITIAL_SETTINGS: GameSettings = { musicVolume: 0.5, }; -export const SAVE_VERSION = 7; +export const SAVE_VERSION = 8; diff --git a/packages/shared/src/types/models.ts b/packages/shared/src/types/models.ts index 5631e55..afa5ee8 100644 --- a/packages/shared/src/types/models.ts +++ b/packages/shared/src/types/models.ts @@ -2,6 +2,7 @@ import type { Era } from './gameState'; import type { DataDomain } from './data'; export type ArchitectureType = 'dense' | 'moe'; +export type SizeTier = 'nano' | 'small' | 'medium' | 'large' | 'flagship'; export interface ModelArchitecture { type: ArchitectureType; @@ -27,13 +28,16 @@ export interface TrainingPipeline { currentStage: TrainingStage; stages: { pretraining: PreTrainingConfig; - sft: SFTConfig | null; - alignment: AlignmentConfig | null; + sft: SFTConfig; + alignment: AlignmentConfig; }; status: TrainingJobStatus; allocatedComputeFraction: number; events: TrainingEvent[]; startedAtTick: number; + sizeTier: SizeTier; + isPointRelease: boolean; + sourceModelId: string | null; } export interface PreTrainingConfig { @@ -125,9 +129,13 @@ export interface BaseModel { trainedAtTick: number; trainingCostTotal: number; trainingStagesCompleted: TrainingStage[]; + sizeTier: SizeTier; + version: number; + sftSpecializations: SFTSpecialization[]; + alignmentMethod: AlignmentMethod | null; } -export type VariantType = 'distilled' | 'fine-tuned' | 'quantized'; +export type VariantType = 'quantized'; export type QuantizationLevel = 'fp16' | 'int8' | 'int4' | 'int2'; export interface ModelVariant { @@ -142,8 +150,6 @@ export interface ModelVariant { isDeployed: boolean; createdAtTick: number; quantization?: QuantizationLevel; - distillationRetention?: number; - finetuneSpecialization?: SFTSpecialization; costMultiplier: number; speedMultiplier: number; } @@ -152,37 +158,25 @@ export interface ModelFamily { id: string; name: string; generation: number; - baseModelId: string | null; + baseModelIds: string[]; variants: ModelVariant[]; createdAtTick: number; } -export type VariantJobType = 'distillation' | 'fine-tuning' | 'quantization'; +export type VariantJobType = 'quantization'; export interface VariantCreationJob { id: string; familyId: string; baseModelId: string; jobType: VariantJobType; - config: DistillationConfig | FineTuneConfig | QuantizationConfig; + config: QuantizationConfig; progressTicks: number; totalTicks: number; allocatedComputeFraction: number; status: 'active' | 'completed'; } -export interface DistillationConfig { - targetParameters: number; - targetArchitecture: ArchitectureType; - variantName: string; -} - -export interface FineTuneConfig { - specialization: SFTSpecialization; - datasetIds: string[]; - variantName: string; -} - export interface QuantizationConfig { level: QuantizationLevel; variantName: string;