BmnRoot
Loading...
Searching...
No Matches
draw_tracks.py
Go to the documentation of this file.
# /// script
# dependencies = [
#   "numpy",
#   "matplotlib"
# ]
# ///
7
"""
Draw a random subset of 3D straight-line tracks from a detector CSV file.

The file is comma-separated with one header line. Each data row encodes a
track as x = xA + xB * z, y = yA + yB * z, with columns:

    0: x    1: y    2: z    3: xA    4: xB    5: yA    6: yB

Lines are drawn over the full z range present in the file. The measured hit
point (x, y, z) is also marked.
"""
28
import argparse
import sys

import matplotlib.pyplot as plt
import numpy as np
34
# ── column indices (same as train.py) ─────────────────────────────────────────
# CSV layout: the measured hit (x, y, z), then the straight-line parameters
# of x = xA + xB * z and y = yA + yB * z.
iX, iY, iZ, iXA, iXB, iYA, iYB = range(7)

# Number of tracks drawn when -n/--count is not given.
DEFAULT_N = 50
40
41
def main():
    """Draw a random sample of tracks from a detector CSV file in 3D.

    Command line: ``csv`` (path to the file), ``-n/--count`` (number of
    tracks, default ``DEFAULT_N``) and ``--seed`` (RNG seed). The CSV is
    read with one header line skipped, followed by comma-separated numeric
    rows whose columns are addressed by the iX ... iYB index constants.

    Each sampled track is drawn as the straight line x = xA + xB * z,
    y = yA + yB * z over a shared z grid, and its measured hit (x, y, z) is
    marked. The plot axes are (Z, X, Y).

    Exits with an error message if --count is < 1, the file cannot be read
    or parsed, it has no data rows, or it has too few columns.
    """
    parser = argparse.ArgumentParser(
        description="Draw random 3D tracks from a detector CSV file."
    )
    parser.add_argument("csv", help="path to detector CSV file")
    parser.add_argument("-n", "--count", default=DEFAULT_N, type=int,
                        help=f"number of random tracks to draw (default {DEFAULT_N})")
    parser.add_argument("--seed", default=None, type=int,
                        help="random seed for reproducibility")
    args = parser.parse_args()
    if args.count < 1:
        # A negative size would make rng.choice raise; zero would draw nothing.
        parser.error("argument -n/--count: must be at least 1")

    # ndmin=2 keeps a single data row as shape (1, ncols) and a single column
    # as (nrows, 1), instead of letting loadtxt squeeze either case to 1-D.
    try:
        data = np.loadtxt(args.csv, delimiter=",", skiprows=1,
                          dtype=np.float64, ndmin=2)
    except (OSError, ValueError) as exc:
        sys.exit(f"error: cannot read {args.csv}: {exc}")

    # Every row must contain all indexed columns, or the slicing below fails.
    n_cols_needed = max(iX, iY, iZ, iXA, iXB, iYA, iYB) + 1
    if data.size == 0:
        sys.exit(f"error: {args.csv} contains no data rows")
    if data.shape[1] < n_cols_needed:
        sys.exit(f"error: {args.csv} has {data.shape[1]} column(s); "
                 f"at least {n_cols_needed} are required")

    # Sample distinct rows; with --seed the selection is reproducible.
    rng = np.random.default_rng(args.seed)
    n = min(args.count, len(data))
    rows = data[rng.choice(len(data), size=n, replace=False)]

    # Shared z grid for every line. It starts at z = 0, or at the smallest hit
    # z if that is negative, and ends past the largest hit z by 10% of
    # max(z_max, 1), so all hits fall inside the z-span of the drawn lines.
    z_all = data[:, iZ]
    z_max = z_all.max()
    z_lo = min(0.0, z_all.min())
    z_hi = z_max + 0.1 * max(z_max, 1.0)
    zz = np.linspace(z_lo, z_hi, 50)

    # Evaluate all sampled lines at once; each result has shape (n, len(zz)).
    xs_all = rows[:, [iXA]] + rows[:, [iXB]] * zz
    ys_all = rows[:, [iYA]] + rows[:, [iYB]] * zz

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection="3d")

    for xs, ys in zip(xs_all, ys_all):
        ax.plot(zz, xs, ys, color="steelblue", alpha=0.4, linewidth=0.8)

    # measured hit points
    ax.scatter(rows[:, iZ], rows[:, iX], rows[:, iY],
               color="tomato", s=12, zorder=5, label="measured hits")

    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
    ax.set_title(f"{n} random tracks — {args.csv}")
    ax.legend()
    plt.tight_layout()
    plt.show()
88
89
if __name__ == "__main__":
    # Run the command-line tool only when executed directly, not on import.
    main()