diff --git a/src/dlstbx/services/strategy.py b/src/dlstbx/services/strategy.py index cd67c86a7..bd7030ccf 100644 --- a/src/dlstbx/services/strategy.py +++ b/src/dlstbx/services/strategy.py @@ -152,9 +152,11 @@ def generate_strategy( """Generate a strategy from the results of an upstream pipeline""" self.log.info("Received strategy request, generating strategy") + recipe_params = rw.recipe_step["parameters"] parameters = ChainMapWithReplacement( message.get("parameters", {}) if isinstance(message, dict) else {}, - rw.recipe_step["parameters"].get("ispyb_parameters", {}), + recipe_params.get("ispyb_parameters", {}), + recipe_params, substitutions=rw.environment, ) self.log.info(f"Received parameters for strategy generation:\n{parameters}") @@ -177,6 +179,9 @@ def generate_strategy( if isinstance(parameters["resolution"], list) else float(parameters["resolution"]) ) + dc_transmission = ( + float(parameters.get("transmission", 100)) / 100 + ) resolution_offset = 0.5 min_resolution = 0.9 resolution = max((resolution_estimate) - resolution_offset, min_resolution) @@ -193,12 +198,16 @@ def generate_strategy( ) beamline_config = parse_config_file(beamline_config_file) - scaled_transmission = parameters.get("scaled_transmission", 1.0) + recommended_max_transmission = parameters.get("scaled_transmission", 1.0) - transmission_limits = ( - get_beamline_param(beamline_config, ("gda.mx.udc.minTransmission",), 0.0), - min(get_beamline_param(beamline_config, ("gda.mx.udc.maxTransmission",), 1.0), - scaled_transmission) + transmission_limits = (get_beamline_param( + beamline_config, ("gda.mx.udc.minTransmission",), 0.0), + min( + get_beamline_param( + beamline_config, ("gda.mx.udc.maxTransmission",), 1.0 + ), + recommended_max_transmission + ), ) exposure_time_limits = ( get_beamline_param( @@ -213,7 +222,7 @@ def generate_strategy( ), ) - recipes = {"OSC.yaml": "Native", "Ligand binding.yaml": "Ligand"} + recipes = {"OSC.yaml": "Native", "Ligand binding.yaml": "Ligand", "SAD.yaml": "Phasing"} ispyb_command_list = [] for recipe, recipe_alias in recipes.items(): @@ -304,7 +313,7 @@ def generate_strategy( ispyb_command_list.append(d) # Convert transmission to percentage for ISPyB - transmission_pct = transmission * 100 + relative_transmission_pct = transmission / dc_transmission * 100 axis_end = ( rotation_start + rotation_increment * recipe_step.number_of_images @@ -318,10 +327,11 @@ def generate_strategy( "axisstart": rotation_start, "axisend": axis_end, "exposuretime": exposure_time, - "transmission": transmission_pct, + "transmission": relative_transmission_pct, "oscillationrange": rotation_increment, "noimages": recipe_step.number_of_images, "resolution": resolution, + "doseTotal": dose, "ispyb_command": "insert_screening_strategy_sub_wedge", "screening_strategy_wedge_id": f"$ispyb_screening_strategy_wedge_id_{n_step}", "store_result": f"ispyb_screening_strategy_sub_wedge_id_{n_step}", diff --git a/src/dlstbx/services/trigger.py b/src/dlstbx/services/trigger.py index d74ff5d1e..f8f69e07f 100644 --- a/src/dlstbx/services/trigger.py +++ b/src/dlstbx/services/trigger.py @@ -4,6 +4,7 @@ import pathlib import re from datetime import datetime, timedelta +from itertools import chain from time import time from typing import Any, Dict, List, Literal, Mapping, Optional @@ -2880,12 +2881,40 @@ def trigger_strategy( ) return {"success": True} - if parameters.beamline not in ["i03", "i04", "i04-1"]: + if parameters.beamline not in ["i04"]: self.log.info( f"Skipping strategy trigger: beamline {parameters.beamline} not supported" ) return {"success": True} + find_process_program = ( + session.query(AutoProcProgram.processingPrograms) + .join( + ProcessingJob, + AutoProcProgram.processingJobId == ProcessingJob.processingJobId, + ) + .filter(ProcessingJob.dataCollectionId == parameters.dcid) + ) + + curr_program = find_process_program.filter( + AutoProcProgram.autoProcProgramId == parameters.program_id + ).scalar() + # xia2 dials occassionaly gives optimistic estimate for resolution + if curr_program == "xia2 dials": + self.log.info( + f"Skipping strategy trigger for dcid={parameters.dcid} from program: xia2 dials." + ) + return {"success": True} + + udc_strategy_previously_triggered = find_process_program.filter( + AutoProcProgram.processingPrograms == "UDC strategy" + ).all() + if udc_strategy_previously_triggered: + self.log.info( + f"Skipping strategy trigger: UDC Strategy has already been triggered for dcid={parameters.dcid}." + ) + return {"success": True} + # Get resolution estimate from ispyb records for upstream pipeline - returns None if not found. resolution = ( session.query(AutoProcScalingStatistics.resolutionLimitHigh) diff --git a/src/dlstbx/wrapper/estimate_transmission.py b/src/dlstbx/wrapper/estimate_transmission.py new file mode 100644 index 000000000..f0840531d --- /dev/null +++ b/src/dlstbx/wrapper/estimate_transmission.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import json +import math +import shutil +import subprocess +from collections import Counter +from itertools import accumulate +from pathlib import Path + +from dials.array_family import flex + +from dlstbx.wrapper import Wrapper + + +class EstimateTransmissionWrapper(Wrapper): + _logger_name = "dlstbx.wrap.estimate_transmission" + + def run(self): + assert hasattr(self, "recwrap"), "No recipewrapper object found" + + params = self.recwrap.recipe_step["job_parameters"] + working_directory = Path(params["working_directory"]) + results_directory = Path(params["results_directory"]) + + beamline = params["beamline"] + pixel_percentile = params["pixel_percentile"].get(beamline, 100) / 100 + target_countrate_pct = params["target_countrate_pct"].get(beamline, 50) / 100 + transmission = float(params["transmission"]) + file = params["input_file"] + + commands = [ + ("dials.import", ["dials.import", file]), + ( + "dials.find_spots", + ["dials.find_spots", "imported.expt", "ice_rings.filter=True"], + ), + ] + + for command, script in commands: + result = subprocess.run(script, cwd=working_directory, check=True) + + if result.returncode: + self.log.info(f"{command} failed with return code {result.returncode}") + self.log.info(result.stderr) + + self.log.debug(f"Command output:\n{result.stdout}") + self.log.debug(f"From command: {script}") + return False + + experiment_file = working_directory / "imported.expt" + with experiment_file.open("r") as f: + experiment = json.load(f) + trusted_range = experiment["detector"][0]["panels"][0]["trusted_range"][1] + + reflection_file = working_directory / "strong.refl" + reflections = flex.reflection_table.from_file(reflection_file) + counts_hist = self.collect_counts_from_reflections(reflections) + + num_counts = list(counts_hist.keys()) + num_pixels = list(counts_hist.values()) + + index_of_pixel_percentile = self.get_percentile_index( + num_pixels, pixel_percentile + ) + counts_at_percentile = int(num_counts[index_of_pixel_percentile]) + + pixel_countrate_pct = counts_at_percentile / trusted_range + self.log.info( + f"The countrate percentage of the {pixel_percentile * 100}% most intense pixel is {pixel_countrate_pct * 100}% of the trusted value" + ) + scale_factor = target_countrate_pct / pixel_countrate_pct + + scaled_transmission = min(1, (transmission * scale_factor) / 100) + self.log.info(f"Scaled transmission is : {scaled_transmission}") + + self.recwrap.send_to( + "strategy", + {"parameters": {"scaled_transmission": float(scaled_transmission)}}, + ) + + results_directory.mkdir(parents=True, exist_ok=True) + log_file_name = [ + f for f in working_directory.glob("*.out") if "slurm" in str(f) + ][0] + output_files = { + "dials.find_spots.log": "dials.find_spots.log", + log_file_name: "estimate_transmission.log", + } + for src_file_name, dest_file_name in output_files.items(): + source_file = working_directory / src_file_name + destination = results_directory / dest_file_name + + if not source_file.exists(): + self.log.info(f"{source_file=} does not exsist") + return False + + self.log.info(f"Copying {str(source_file)} to {str(destination)}") + shutil.copy(source_file, destination) + + self.draw_plot(num_counts, num_pixels) + self.save_hist_to_json(counts_hist, trusted_range, results_directory) + + self.log.info("Done.") + return True + + def collect_counts_from_reflections(self, reflections): + "Iterate through the shoeboxes to a reflection and generate a pixel histogram" + + shoeboxes = reflections["shoebox"] + counter = Counter() + for sbox in shoeboxes: + counter.update(sbox.data.as_numpy_array().ravel()) + + sorted_counter = sorted(counter.items()) + return {str(int(k)): v for k, v in sorted_counter} + + def get_percentile_index(self, num_pixels, percentile): + threshold = sum(num_pixels) * percentile + + for i, cum_sum in enumerate(accumulate(num_pixels)): + if cum_sum >= threshold: + return i + + return len(num_pixels) + + def save_hist_to_json(self, hist, max_trusted_value, results_dir): + results_path = results_dir / "overload.json" + self.log.info(f"Saving counts histogram to {str(results_path)}") + with open(results_path, "w") as f: + json.dump( + {"counts": hist, "overload_limit": max_trusted_value}, f, indent=2 + ) + + self.log.info("Saved.") + + def draw_plot(self, counts, pixels): + """Create ASCII art histogram""" + + self.log.info("Plotting pixel intensities...") + + width, height = 60, 20 + title = "'Pixel intensity distribution'" + xlabel = "'Num counts'" + ylabel = "'Counts'" + + command = ["gnuplot"] + plot_commands = [ + "set term dumb %d %d" % (width, height - 2), + "set logscale y", + "set logscale x", + "set ytics out", + "set title %s" % title, + "set xlabel %s" % xlabel, + "set ylabel %s offset character %d,0" % (ylabel, len(ylabel) // 2), + ] + + data_string = "\n".join(f"{x} {y}" for x, y in zip(counts[1:], pixels[1:])) + + plot_commands.append("plot '-' using 1:2 title '' with lines") + plot_commands.append(data_string) + + try: + result = subprocess.run( + command, input="\n".join(plot_commands), text=True, capture_output=True + ) + except (OSError, subprocess.TimeoutExpired) as e: + self.log.info("Error plotting counts vs pixels") + self.log.info(e) + return + else: + self.log.info(result.stdout) diff --git a/src/dlstbx/wrapper/xia2_overload.py b/src/dlstbx/wrapper/xia2_overload.py deleted file mode 100644 index e8e3866c6..000000000 --- a/src/dlstbx/wrapper/xia2_overload.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import json -import math -import subprocess - -from pathlib import Path -import shutil - -from dlstbx.wrapper import Wrapper - - -class Xia2OverloadWrapper(Wrapper): - _logger_name = "dlstbx.wrap.xia2_overload" - - def run(self): - assert hasattr(self, "recwrap"), "No recipewrapper object found" - - params = self.recwrap.recipe_step["job_parameters"] - working_directory = Path(params["working_directory"]) - results_directory = Path(params["results_directory"]) - - target_countrate_pct = float(params["target_countrate_pct"]) - oscillation = float(params["oscillation"]) - transmission = float(params["transmission"]) - - file = params["input_file"] - - command = [f"xia2.overload {file}"] - - result = subprocess.run( - command, shell=True, cwd=working_directory, capture_output=True - ) - - if result.returncode: - self.log.info(f"xia2.overload failed with return code {result.returncode}") - self.log.info(result.stderr) - self.log.debug(f"Command output:\n{result.stdout}") - return False - - results_directory.mkdir(parents=True, exist_ok=True) - output_file_name = "overload.json" - - source_file = working_directory / output_file_name - destination = results_directory / output_file_name - - if not source_file.exists(): - return False - - self.log.debug(f"Copying {str(source_file)} to {str(destination)}") - shutil.copy2(source_file, destination) - - self.record_result_individual_file( - { - "file_path": str(destination.parent), - "file_name": destination.name, - "file_type": "result", - } - ) - - with source_file.open("r") as f: - data = json.load(f) - counts = data["counts"] - overload_limit = float(data["overload_limit"]) - - max_count = float(list(counts)[-1]) - - mosaicity_corr = params.get("mosaicity_correction", False) - average_to_peak = ( - self.mosaicity_correction(mosaicity_corr, oscillation) - if mosaicity_corr - else 1 - ) - - saturation = (max_count / overload_limit) * average_to_peak - scale_factor = target_countrate_pct / saturation - - scaled_transmission = min(1, ( transmission * scale_factor ) / 100) - - self.recwrap.send_to("strategy", {"parameters": {"scaled_transmission": scaled_transmission}}) - self.log.info("Done.") - return True - - def mosaicity_correction(self, moscaicity_coefficent: float, oscillation: float): - delta_z = oscillation / (moscaicity_coefficent) * math.sqrt(2) - average_to_peak = ( - math.sqrt(math.pi) * delta_z * math.erf(delta_z) - + math.exp(-(delta_z * delta_z)) - - 1 - ) / (delta_z * delta_z) - self.log.info("Average-to-peak intensity ratio: %f", average_to_peak) - return average_to_peak