BmnRoot
Loading...
Searching...
No Matches
train.py
Go to the documentation of this file.
# /// script
# dependencies = [
#     "torch==2.6.0",
#     "numpy"
# ]
#
# [tool.uv.sources]
# torch = { index = "pytorch" }
#
# [[tool.uv.index]]
# name = "pytorch"
# url = "https://download.pytorch.org/whl/cu124"
# ///
14
15"""
16 BM@N alignment routine
17 BM@N experiment at NICA complex, JINR, 2025
18
19 Department: Math & Soft Group of HEP lab
20 Author: Igor Polev, polev@jinr.ru
21
22 MSE minimization of alignment problem using NN training techniques.
23 For each alignable detector element we prepare a custom single-layer NN model
24 and train it to get trained NN weights as a solution for alignment parameters.
25 Custom layer consists of measuremets model for defined experiment setup.
26 Datasets are prepared by expdataNN.C from original hits and tracks
27 alignment data in form of separate CSV files for each detector element:
28
29 inputs (5): z, xA, xB, yA, yB — columns [2..6]
30 outputs (2): x, y — columns [0..1]
31
32 where xA, xB, yA, yB are coefficients of straight-line fit:
33
34 x = xA + xB*z, y = yA + yB*z
35
36 This fit is computed from non-aligned data using ordinary linear regression
37 during data export procedure (see expdataNN.C).
38"""
39
40import argparse
41import glob
42import json
43import os
44import re
45import sys
46
47import numpy as np
48import torch
49import torch.nn as nn
50from torch.utils.data import DataLoader, TensorDataset
51
# ── defaults ──────────────────────────────────────────────────────────────────
DEFAULT_GLOB = "/home/igor/DATA/bmn/nn/data_det*.csv"
DEFAULT_EPOCHS = 100
DEFAULT_LR = 1e-3
DEFAULT_BATCH = 256
DEFAULT_OUT_DIR = "models"

# ── dataset layout ────────────────────────────────────────────────────────────
# Raw CSV column indices: x, y, z, xA, xB, yA, yB (see module docstring).
iX, iY, iZ = 0, 1, 2
iXA, iXB, iYA, iYB = 3, 4, 5, 6

# Input features MUST be in raw-column order z, xA, xB, yA, yB (columns 2..6,
# ascending): AlignLayer.forward maps feature column i back to raw column
# i + 2 via the "i* - 2" offsets, which only holds for a contiguous ascending
# selection.  BUG FIX: the previous order [iZ, iXA, iYA, iXB, iYB] swapped xB
# and yA, feeding the model yA as the x-slope and xB as the y-intercept.
INPUT_COLS = [iZ, iXA, iXB, iYA, iYB]
OUTPUT_COLS = [iX, iY]
65
# ── NN definition ─────────────────────────────────────────────────────────────
#
# Custom layer that mimics the alignment task as a neural network.
#
# Input : z, xA, xB, yA, yB  (shape: [batch, 5])
# Output: x, y               (shape: [batch, 2])
#
# NN parameters = alignment parameters:
#     shift = [dx, dy, dz]
#
# The layer math is the measurement model accounting for (mis)alignment params.
77
class AlignLayer(nn.Module):
    """Measurement model as a trainable layer: hit = line(z - dz) + (dx, dy).

    Trainable parameters (the alignment solution):
      shiftV -- (dx, dy), transverse shifts
      shiftZ -- dz, shift along the beam axis
    """

    def __init__(self):
        super().__init__()
        # Start from zero misalignment on all three parameters.
        self.shiftV = nn.Parameter(torch.zeros(2))
        self.shiftZ = nn.Parameter(torch.zeros(1))

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # The input tensor holds raw CSV columns shifted left by 2 (the x, y
        # targets are dropped), hence the "- 2" in every column index.
        intercepts = batch[:, [iXA - 2, iYA - 2]]
        slopes = batch[:, [iXB - 2, iYB - 2]]
        z_shifted = batch[:, [iZ - 2]] - self.shiftZ
        # Straight-line prediction, then apply the transverse shift.
        return intercepts + slopes * z_shifted + self.shiftV
90
def load_dataset(path: str) -> TensorDataset:
    """Read one detector CSV (header row skipped) into a TensorDataset of
    float32 (inputs, targets) tensor pairs."""
    table = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
    features = torch.from_numpy(table[:, INPUT_COLS])
    targets = torch.from_numpy(table[:, OUTPUT_COLS])
    return TensorDataset(features, targets)
97
def train_one(
    dataset: TensorDataset,
    det_id: str,
    epochs: int,
    lr: float,
    batch: int,
    out_dir: str,
    device: torch.device,
) -> tuple[float, float, float]:
    """Train one AlignLayer on *dataset* and return (dx, dy, dz) weights.

    Fits a fresh model with Adam + MSE, logs progress roughly every 10% of
    the epochs, saves the state dict to <out_dir>/net_<det_id>.pt, and
    returns the learned alignment shifts.

    NOTE(fix): the original ``def train_one(`` signature line was lost in
    extraction; restored here to match the parameter list and return type.
    """
    loader = DataLoader(dataset, batch_size=batch, shuffle=True)
    model = AlignLayer().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    def _shifts() -> tuple[float, float, float]:
        # Read the current alignment parameters without tracking gradients.
        with torch.no_grad():
            dx, dy = model.shiftV.cpu().tolist()
            dz = model.shiftZ.cpu().item()
        return dx, dy, dz

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for X_batch, Y_batch in loader:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)
            opt.zero_grad()
            loss = loss_fn(model(X_batch), Y_batch)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is exact even when
            # the final batch is smaller than the rest.
            total_loss += loss.item() * len(X_batch)

        # Report ~10 times per run, and always on the final epoch.
        if epoch % max(1, epochs // 10) == 0 or epoch == epochs:
            mse = total_loss / len(dataset)
            dx, dy, dz = _shifts()
            print(f" [{det_id}] epoch {epoch:>{len(str(epochs))}}/{epochs} MSE={mse:.6e} dx={dx:+.6f} dy={dy:+.6f} dz={dz:+.6f}")

    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"net_{det_id}.pt")
    torch.save(model.state_dict(), save_path)
    print(f" [{det_id}] saved -> {save_path}")

    return _shifts()
141
def main():
    """CLI entry point: train one alignment model per matched detector CSV
    and write the combined solution JSON."""
    cli = argparse.ArgumentParser(
        description="Train one linear NN per detector CSV file."
    )
    cli.add_argument("--data", default=DEFAULT_GLOB, help="glob pattern for CSV files")
    cli.add_argument("--epochs", default=DEFAULT_EPOCHS, type=int)
    cli.add_argument("--lr", default=DEFAULT_LR, type=float)
    cli.add_argument("--batch", default=DEFAULT_BATCH, type=int)
    cli.add_argument("--outdir", default=DEFAULT_OUT_DIR, help="directory to save trained models")
    cli.add_argument("--solution", default=None, help="path to save solution JSON (default: <outdir>/solution.json)")
    opts = cli.parse_args()

    csv_files = sorted(glob.glob(opts.data))
    if not csv_files:
        print(f"No CSV files matched: {opts.data}", file=sys.stderr)
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Found {len(csv_files)} detector file(s)\n")

    corrections = []
    for path in csv_files:
        # File stem, e.g. "data_det12345"; trailing digits are the detector id.
        stem = os.path.splitext(os.path.basename(path))[0]
        match = re.search(r"(\d+)$", stem)
        det_id = int(match.group(1)) if match else stem
        print(f"==> {path} (detector {det_id})")
        dataset = load_dataset(path)
        print(f" samples: {len(dataset)}")
        # Models are saved under the full stem; the JSON uses the numeric id.
        dx, dy, dz = train_one(dataset, stem, opts.epochs, opts.lr, opts.batch, opts.outdir, device)
        corrections.append({"detector_id": det_id, "values": [dx, dy, dz]})
        print()

    solution = {
        "CorrectionValues": ["dX", "dY", "dZ"],
        "CorrectionsPerDetector": corrections,
        "DetectorElements": len(corrections),
    }
    solution_path = opts.solution or os.path.join(opts.outdir, "solution.json")
    os.makedirs(os.path.dirname(os.path.abspath(solution_path)), exist_ok=True)
    with open(solution_path, "w") as f:
        json.dump(solution, f, indent=4)
    print(f"Solution saved -> {solution_path}")
    print("All done.")
186
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
friend F32vec4 max(const F32vec4 &a, const F32vec4 &b)
Definition P4_F32vec4.h:31
__init__(self)
Definition train.py:79
torch.Tensor forward(self, torch.Tensor batch)
Definition train.py:84
main()
Definition train.py:142
TensorDataset load_dataset(str path)
Definition train.py:91
tuple[float, float, float] train_one(TensorDataset dataset, str det_id, int epochs, float lr, int batch, str out_dir, torch.device device)
Definition train.py:106