iXA, iXB, iYA, iYB = 3, 4, 5, 6          # CSV column indices
INPUT_COLS = [iZ, iXA, iYA, iXB, iYB]    # network input columns (iZ and OUTPUT_COLS are defined in lines omitted from this excerpt)
# CSV loading (body of the loader helper; its def line is not shown in this excerpt)
raw = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
return TensorDataset(
    torch.from_numpy(raw[:, INPUT_COLS]),
    torch.from_numpy(raw[:, OUTPUT_COLS])
)
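# The model class itself is not part of this excerpt; the training loop below only
# touches model.shiftV (two entries) and model.shiftZ (a scalar). The class here is a
# minimal sketch under those assumptions: the name "ShiftModel" and the forward pass
# are guesses for illustration, not the original implementation.
class ShiftModel(nn.Module):
    """Per-detector alignment: learnable (dx, dy) shift plus a dz offset."""

    def __init__(self) -> None:
        super().__init__()
        self.shiftV = nn.Parameter(torch.zeros(2))   # (dx, dy)
        self.shiftZ = nn.Parameter(torch.zeros(()))  # dz

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X columns follow INPUT_COLS = [iZ, iXA, iYA, iXB, iYB].
        # Illustrative geometry only: shift hit A in (x, y) and the plane in z.
        z = X[:, 0:1] + self.shiftZ
        xy_a = X[:, 1:3] + self.shiftV
        return torch.cat([xy_a, z], dim=1)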
def train_one(
    dataset: TensorDataset,
    det_id: str,
    epochs: int,
    lr: float,
    batch: int,
    out_dir: str,
    device: torch.device,
) -> tuple[float, float, float]:
    """Train and return (dx, dy, dz) weights."""
    loader = DataLoader(dataset, batch_size=batch, shuffle=True)
    # model construction is elided here; the module exposes shiftV (dx, dy) and shiftZ (dz)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for X_batch, Y_batch in loader:
            X_batch = X_batch.to(device)
            Y_batch = Y_batch.to(device)

            opt.zero_grad()
            loss = loss_fn(model(X_batch), Y_batch)
            loss.backward()
            opt.step()

            # accumulate a sample-weighted sum so the epoch MSE below is per sample
            total_loss += loss.item() * len(X_batch)

        if epoch % max(1, epochs // 10) == 0 or epoch == epochs:
            mse = total_loss / len(dataset)
            with torch.no_grad():
                dx, dy = model.shiftV.cpu().tolist()
                dz = model.shiftZ.cpu().item()
            print(f" [{det_id}] epoch {epoch:>{len(str(epochs))}}/{epochs} MSE={mse:.6e} dx={dx:+.6f} dy={dy:+.6f} dz={dz:+.6f}")
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, f"net_{det_id}.pt")
    torch.save(model.state_dict(), save_path)
    print(f" [{det_id}] saved -> {save_path}")
    with torch.no_grad():
        dx, dy = model.shiftV.cpu().tolist()
        dz = model.shiftZ.cpu().item()
    return dx, dy, dz
parser = argparse.ArgumentParser(
    description="Train one linear NN per detector CSV file."
)
parser.add_argument("--data", default=DEFAULT_GLOB, help="glob pattern for CSV files")
parser.add_argument("--epochs", default=DEFAULT_EPOCHS, type=int)
parser.add_argument("--lr", default=DEFAULT_LR, type=float)
parser.add_argument("--batch", default=DEFAULT_BATCH, type=int)
parser.add_argument("--outdir", default=DEFAULT_OUT_DIR, help="directory to save trained models")
parser.add_argument("--solution", default=None, help="path to save solution JSON (default: <outdir>/solution.json)")
args = parser.parse_args()
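# Example invocation (a sketch, not taken from the original file; the script name and
# the flag values are placeholders -- substitute your own paths and settings):
#   python train_detectors.py --data "data/det_*.csv" --epochs 500 --lr 1e-3 \
#       --batch 256 --outdir trained --solution trained/solution.json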
csv_files = sorted(glob.glob(args.data))
if not csv_files:
    print(f"No CSV files matched: {args.data}", file=sys.stderr)
    sys.exit(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print(f"Found {len(csv_files)} detector file(s)\n")
corrections = []
for path in csv_files:
    stem = os.path.splitext(os.path.basename(path))[0]
    m = re.search(r"(\d+)$", stem)
    det_id = int(m.group(1)) if m else stem   # trailing digits of the file name give the detector id
    print(f"==> {path} (detector {det_id})")
    # dataset construction from `path` is elided here (see the CSV loading snippet above)
    print(f" samples: {len(dataset)}")
    dx, dy, dz = train_one(dataset, stem, args.epochs, args.lr, args.batch, args.outdir, device)
    corrections.append({"detector_id": det_id, "values": [dx, dy, dz]})
solution = {
    "CorrectionValues": ["dX", "dY", "dZ"],
    "CorrectionsPerDetector": corrections,
    "DetectorElements": len(corrections),
}
solution_path = args.solution or os.path.join(args.outdir, "solution.json")
os.makedirs(os.path.dirname(os.path.abspath(solution_path)), exist_ok=True)
with open(solution_path, "w") as f:
    json.dump(solution, f, indent=4)
print(f"Solution saved -> {solution_path}")
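# Minimal read-back sketch (not part of the original script): load the solution JSON
# written above and print one correction per detector. The path below is only an
# example; pass whatever --solution / --outdir produced.
import json

with open("trained/solution.json") as f:
    solution = json.load(f)
for entry in solution["CorrectionsPerDetector"]:
    dx, dy, dz = entry["values"]
    print(f"detector {entry['detector_id']}: dx={dx:+.6f} dy={dy:+.6f} dz={dz:+.6f}")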