2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -132,7 +132,7 @@ jobs:
with:
cvmfs_repositories: 'oasis.opensciencegrid.org'
- name: unit tests
-run: ./build-coatjava.sh --cvmfs --unittests --no-progress -T${{ env.nthreads }}
+run: ./build-coatjava.sh --lfs --unittests --no-progress -T${{ env.nthreads }}
- name: collect jacoco report
if: ${{ matrix.JAVA_VERSION == env.JAVA_VERSION }}
run: validation/jacoco-aggregate.sh
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
@@ -94,7 +94,7 @@ download:
dependencies: [build]
script:
- tar -xzf coatjava.tar.gz
-- ./build-coatjava.sh -T$JL_RUNNER_AVAIL_CPU --unittests --quiet --no-progress
+- ./build-coatjava.sh -T$JL_RUNNER_AVAIL_CPU --lfs --unittests --quiet --no-progress
- ./validation/jacoco-aggregate.sh
artifacts:
when: always
24 changes: 23 additions & 1 deletion etc/bankdefs/hipo4/alert.json
@@ -55,7 +55,18 @@
"info": "path length inside atof wedge in mm"
}
]
-},{
+},
+{
+"name": "ALERT::ai:projections",
+"group": 23000,
+"item": 32,
+"info": "Track Projections to ATOF given by AI",
+"entries": [
+{"name": "trackid", "type": "I", "info": "track id"},
+{"name": "matched_atof_hit_id", "type": "I", "info": "id of the matched ATOF hit, -1 if no hit was matched"}
+]
+},
+{
"name": "ATOF::hits",
"group": 22500,
"item": 21,
@@ -414,5 +425,16 @@
{"name": "y5", "type": "F", "info": "Y5 position of the 5th superprecluster (mm)"},
{"name": "pred", "type": "F", "info": "Prediction of the model: 0 mean bad track; 1 mean good track"}
]
+},
+{
+"name": "AHDC::interclusters",
+"group": 23000,
+"item": 27,
+"info": "InterClusters info",
+"entries": [
+{"name": "trackid", "type": "I", "info": "track id"},
+{"name": "x", "type": "F", "info": "x info (mm)"},
+{"name": "y", "type": "F", "info": "y info (mm)"}
+]
}
]
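As a usage illustration for these new bank definitions, the sketch below fills ALERT::ai:projections through coatjava's generic DataEvent/DataBank interface. This is a minimal sketch, not code from the PR: the writer class, its arguments, and the projections list are hypothetical; only the bank name and entry names come from the definition above.

    import org.jlab.io.base.DataBank;
    import org.jlab.io.base.DataEvent;
    import java.util.List;

    // Hypothetical helper: one row per AI track projection, [trackid, matched ATOF hit id or -1].
    public class ProjectionBankWriter {
        public static void write(DataEvent event, List<int[]> projections) {
            DataBank bank = event.createBank("ALERT::ai:projections", projections.size());
            for (int row = 0; row < projections.size(); row++) {
                bank.setInt("trackid", row, projections.get(row)[0]);
                bank.setInt("matched_atof_hit_id", row, projections.get(row)[1]);
            }
            event.appendBanks(bank); // assumed here: coatjava's usual call for attaching output banks
        }
    }

The next hunk below belongs to AIPrediction.java.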
@@ -2,32 +2,34 @@

import java.util.ArrayList;

-import ai.djl.MalformedModelException;
-import ai.djl.inference.Predictor;
-import ai.djl.repository.zoo.ModelNotFoundException;
-import ai.djl.repository.zoo.ZooModel;
-import ai.djl.translate.TranslateException;

-import java.io.IOException;

public class AIPrediction {


-public AIPrediction() throws ModelNotFoundException, MalformedModelException, IOException {
-}
+public AIPrediction() {}

-public ArrayList<TrackPrediction> prediction(ArrayList<ArrayList<PreclusterSuperlayer>> tracks, ZooModel<float[], Float> model) throws TranslateException {
+public ArrayList<TrackPrediction> prediction(ArrayList<ArrayList<InterCluster>> tracks, ModelTrackFinding modelTrackFinding) throws Exception {
ArrayList<TrackPrediction> result = new ArrayList<>();
-for (ArrayList<PreclusterSuperlayer> track : tracks) {
-float[] a = new float[]{(float) track.get(0).getX(), (float) track.get(0).getY(),
-(float) track.get(1).getX(), (float) track.get(1).getY(),
-(float) track.get(2).getX(), (float) track.get(2).getY(),
-(float) track.get(3).getX(), (float) track.get(3).getY(),
-(float) track.get(4).getX(), (float) track.get(4).getY(),
-};
-
-Predictor<float[], Float> my_predictor = model.newPredictor();
-result.add(new TrackPrediction(my_predictor.predict(a), track));
-
+if (tracks.isEmpty()) return result;
+
+float[][] batchInput = new float[tracks.size()][10];
+for (int i = 0; i < tracks.size(); i++) {
+ArrayList<InterCluster> track = tracks.get(i);
+batchInput[i][0] = (float) track.get(0).getX();
+batchInput[i][1] = (float) track.get(0).getY();
+batchInput[i][2] = (float) track.get(1).getX();
+batchInput[i][3] = (float) track.get(1).getY();
+batchInput[i][4] = (float) track.get(2).getX();
+batchInput[i][5] = (float) track.get(2).getY();
+batchInput[i][6] = (float) track.get(3).getX();
+batchInput[i][7] = (float) track.get(3).getY();
+batchInput[i][8] = (float) track.get(4).getX();
+batchInput[i][9] = (float) track.get(4).getY();
+}
+
+float[] predictions = modelTrackFinding.batchPredict(batchInput);
+for (int i = 0; i < tracks.size(); i++) {
+result.add(new TrackPrediction(predictions[i], tracks.get(i)));
+}
+
return result;
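To see the batched flow end to end, here is a hypothetical driver, a sketch assuming the InterCluster, ModelTrackFinding, and TrackPrediction classes as they appear in this PR; candidate building is out of scope and left to the caller.

    package org.jlab.rec.ahdc.AI; // same package as the PR classes, so no extra imports needed

    import java.util.ArrayList;

    // Hypothetical driver for the new batched prediction path.
    public class AIPipelineSketch {
        public static ArrayList<TrackPrediction> score(
                ArrayList<ArrayList<InterCluster>> candidates) throws Exception {
            ModelTrackFinding model = new ModelTrackFinding(); // loads the PyTorch model once
            AIPrediction predictor = new AIPrediction();
            // One batchPredict call over all candidates replaces the old per-track Predictor loop.
            return predictor.prediction(candidates, model);
        }
    }

The point of the rewrite is visible in the diff above: the old loop created a fresh Predictor and ran one predict() per track, while the new code assembles a single [nTracks][10] feature matrix and pays the inference overhead once.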
@@ -1,17 +1,16 @@
package org.jlab.rec.ahdc.AI;

-import org.jlab.rec.ahdc.Hit.Hit;
import org.jlab.rec.ahdc.PreCluster.PreCluster;

import java.util.ArrayList;

-public class PreclusterSuperlayer {
+public class InterCluster {
+private int trackId = -1;
private final double x;
private final double y;
private ArrayList<PreCluster> preclusters = new ArrayList<>();


-; public PreclusterSuperlayer(ArrayList<PreCluster> preclusters_) {
+public InterCluster(ArrayList<PreCluster> preclusters_) {
this.preclusters = preclusters_;
double x_ = 0;
double y_ = 0;
@@ -43,6 +42,13 @@ public int getSuperlayer() {
return this.preclusters.get(0).get_Super_layer();
}

+public int getTrackId() {
+return trackId;
+}
+
+public void setTrackId(int trackId) {
+this.trackId = trackId;
+}

public String toString() {
return "PreCluster{" + "X: " + this.x + " Y: " + this.y + " phi: " + Math.atan2(this.y, this.x) + "}\n";
@@ -21,29 +21,28 @@
*
* \todo fix class name
*/
-public class Model {
-private ZooModel<float[], Float> model;
+public class ModelTrackFinding {
+private final ZooModel<float[], Float> model;

-public Model() {
-Translator<float[], Float> my_translator = new Translator<float[], Float>() {
+public ModelTrackFinding() {
+Translator<float[], Float> my_translator = new Translator<>() {
@Override
public Float processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
return ndList.get(0).getFloat();
}

@Override
public NDList processInput(TranslatorContext translatorContext, float[] floats) throws Exception {
-NDManager manager = NDManager.newBaseManager();
-NDArray samples = manager.zeros(new Shape(floats.length));
-samples.set(floats);
+NDManager manager = translatorContext.getNDManager();
+NDArray samples = manager.create(floats);
return new NDList(samples);
}
};
System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
System.setProperty("ai.djl.pytorch.num_threads", "1");
System.setProperty("ai.djl.pytorch.graph_optimizer", "false");

String path = CLASResources.getResourcePath("etc/data/nnet/ALERT/model_AHDC/");
String path = CLASResources.getResourcePath("etc/data/nnet/rg-l/model_AHDC/");
Criteria<float[], Float> my_model = Criteria.builder().setTypes(float[].class, Float.class)
.optModelPath(Paths.get(path))
.optEngine("PyTorch")
@@ -63,4 +62,41 @@ public NDList processInput(TranslatorContext translatorContext, float[] floats)
public ZooModel<float[], Float> getModel() {
return model;
}

+/**
+* Batch prediction for improved performance.
+* Predicts all tracks at once instead of one at a time.
+* This is significantly faster due to reduced overhead and better GPU utilization.
+*
+* @param inputs Array of input features for each track
+* @return Array of predictions for each track
+*/
+public float[] batchPredict(float[][] inputs) throws Exception {
+if (inputs == null || inputs.length == 0) {
+return new float[0];
+}
+
+try (NDManager manager = NDManager.newBaseManager()) {
+int batchSize = inputs.length;
+NDArray batchInput = manager.create(inputs);
+NDList inputList = new NDList(batchInput);
+ai.djl.inference.Predictor<NDList, NDList> rawPredictor = model.newPredictor(new ai.djl.translate.NoopTranslator());
+NDList output = rawPredictor.predict(inputList);
+
+NDArray outputArray = output.get(0);
+float[] results = new float[batchSize];
+
+if (outputArray.getShape().dimension() == 2) {
+for (int i = 0; i < batchSize; i++) {
+results[i] = outputArray.get(i, 0).getFloat();
+}
+} else {
+for (int i = 0; i < batchSize; i++) {
+results[i] = outputArray.get(i).getFloat();
+}
+}
+
+return results;
+}
+}
}
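A quick smoke test of batchPredict might look like the sketch below. This is hypothetical: the feature values are dummies (the real inputs are the (x, y) centroids of 5 interclusters per candidate, in mm), and the constructor needs the model files under the CLASResources path to be present.

    package org.jlab.rec.ahdc.AI; // assumed: same package as ModelTrackFinding

    public class BatchPredictDemo {
        public static void main(String[] args) throws Exception {
            ModelTrackFinding model = new ModelTrackFinding();
            // Two dummy candidates, 10 features each: (x, y) of 5 interclusters.
            float[][] inputs = {
                {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f},
                {1.0f, 0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f, 0.3f, 0.2f, 0.1f}
            };
            float[] scores = model.batchPredict(inputs); // one forward pass for the whole batch
            System.out.println(scores[0] + " " + scores[1]);
        }
    }

The NoopTranslator path in batchPredict bypasses the per-sample Translator entirely, which is why the method has to handle both 1-D and 2-D output shapes itself.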
@@ -3,119 +3,32 @@
import org.jlab.rec.ahdc.Hit.Hit;
import org.jlab.rec.ahdc.PreCluster.PreCluster;

-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Comparator;
-import java.util.List;
+import java.util.*;

public class PreClustering {
+static final double DISTANCE_MAX = 8.0;

-private ArrayList<Hit> fill(List<Hit> hits, int super_layer, int layer) {
-
-ArrayList<Hit> result = new ArrayList<>();
-for (Hit hit : hits) {
-if (hit.getSuperLayerId() == super_layer && hit.getLayerId() == layer) result.add(hit);
-}
-return result;
-}
-
-public ArrayList<PreCluster> find_preclusters_for_AI(List<Hit> AHDC_hits) {
-ArrayList<PreCluster> preclusters = new ArrayList<>();
-
-ArrayList<Hit> s1l1 = fill(AHDC_hits, 1, 1);
-ArrayList<Hit> s2l1 = fill(AHDC_hits, 2, 1);
-ArrayList<Hit> s2l2 = fill(AHDC_hits, 2, 2);
-ArrayList<Hit> s3l1 = fill(AHDC_hits, 3, 1);
-ArrayList<Hit> s3l2 = fill(AHDC_hits, 3, 2);
-ArrayList<Hit> s4l1 = fill(AHDC_hits, 4, 1);
-ArrayList<Hit> s4l2 = fill(AHDC_hits, 4, 2);
-ArrayList<Hit> s5l1 = fill(AHDC_hits, 5, 1);
-
-// Sort hits of each layers by phi:
-Comparator<Hit> comparator = new Comparator<>() {
-@Override
-public int compare(Hit a1, Hit a2) {
-return Double.compare(a1.getPhi(), a2.getPhi());
-}
-};
-
-s1l1.sort(comparator);
-s2l1.sort(comparator);
-s2l2.sort(comparator);
-s3l1.sort(comparator);
-s3l2.sort(comparator);
-s4l1.sort(comparator);
-s4l2.sort(comparator);
-s5l1.sort(comparator);
-
-ArrayList<ArrayList<Hit>> all_super_layer = new ArrayList<>(Arrays.asList(s1l1, s2l1, s2l2, s3l1, s3l2, s4l1, s4l2, s5l1));
-
-for (ArrayList<Hit> p : all_super_layer) {
-for (Hit hit : p) {
-hit.setUse(false);
-}
-}
-
-for (ArrayList<Hit> p : all_super_layer) {
-for (Hit hit : p) {
-if (hit.is_NoUsed()) {
-ArrayList<Hit> temp = new ArrayList<>();
-temp.add(hit);
-hit.setUse(true);
-int expected_wire_plus = hit.getWireId() + 1;
-int expected_wire_minus = hit.getWireId() - 1;
-if (hit.getWireId() == 1)
-expected_wire_minus = hit.getNbOfWires();
-if (hit.getWireId() == hit.getNbOfWires() )
-expected_wire_plus = 1;
-
-
-boolean has_next = true;
-while (has_next) {
-has_next = false;
-for (Hit hit1 : p) {
-if (hit1.is_NoUsed() && (hit1.getWireId() == expected_wire_minus || hit1.getWireId() == expected_wire_plus)) {
-temp.add(hit1);
-hit1.setUse(true);
-has_next = true;
-break;
-}
-}
-}
-if (!temp.isEmpty()) preclusters.add(new PreCluster(temp));
-}
-}
-}
-return preclusters;
-}
-
-public ArrayList<PreclusterSuperlayer> merge_preclusters(ArrayList<PreCluster> preclusters) {
-double distance_max = 8.0;
-
-ArrayList<PreclusterSuperlayer> superpreclusters = new ArrayList<>();
+public ArrayList<InterCluster> mergePreclusters(ArrayList<PreCluster> preclusters) {
+ArrayList<InterCluster> interclusters = new ArrayList<>();
for (PreCluster precluster : preclusters) {
if (!precluster.is_Used()) {
ArrayList<PreCluster> tmp = new ArrayList<>();
tmp.add(precluster);
precluster.set_Used(true);
for (PreCluster other : preclusters) {
-if (precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getSuperLayerId() == other.get_hits_list().get(other.get_hits_list().size() - 1).getSuperLayerId() && precluster.get_hits_list().get(precluster.get_hits_list().size() - 1).getLayerId() != other.get_hits_list().get(other.get_hits_list().size() - 1).getLayerId() && !other.is_Used()) {
-double dx = precluster.get_X() - other.get_X();
-double dy = precluster.get_Y() - other.get_Y();
-double distance = Math.sqrt(dx * dx + dy * dy);
-
-if (distance < distance_max) {
+if (precluster.get_hits_list().getLast().getSuperLayerId() == other.get_hits_list().getLast().getSuperLayerId()
+&& precluster.get_hits_list().getLast().getLayerId() != other.get_hits_list().getLast().getLayerId()
+&& !other.is_Used()) {
+if (Math.hypot(precluster.get_X() - other.get_X(), precluster.get_Y() - other.get_Y()) < DISTANCE_MAX) {
other.set_Used(true);
tmp.add(other);
}
}
}

-if (!tmp.isEmpty()) superpreclusters.add(new PreclusterSuperlayer(tmp));
+if (!tmp.isEmpty()) interclusters.add(new InterCluster(tmp));
}
}

-return superpreclusters;
+return interclusters;
}


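The merge condition above is compact, so here is a self-contained restatement of just the distance cut: two preclusters in the same superlayer but different layers merge when their transverse separation is below DISTANCE_MAX = 8.0 mm. The centroid coordinates are made up for illustration.

    // Hypothetical illustration of the mergePreclusters distance criterion.
    public class MergeCriterionDemo {
        static final double DISTANCE_MAX = 8.0; // mm, as in PreClustering

        static boolean shouldMerge(double x1, double y1, double x2, double y2) {
            // same Math.hypot form as the new code in this PR
            return Math.hypot(x1 - x2, y1 - y2) < DISTANCE_MAX;
        }

        public static void main(String[] args) {
            System.out.println(shouldMerge(30.0, 5.0, 33.0, 7.0));  // true: ~3.6 mm apart
            System.out.println(shouldMerge(30.0, 5.0, 45.0, 20.0)); // false: ~21.2 mm apart
        }
    }

Switching from the explicit dx/dy/sqrt computation to Math.hypot keeps the cut identical while letting the superlayer/layer/used checks short-circuit before any distance is computed.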