Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 125 additions & 29 deletions corelib/src/libs/SireCAS/lambdaschedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static RegisterMetaType<LambdaSchedule> r_schedule;

QDataStream &operator<<(QDataStream &ds, const LambdaSchedule &schedule)
{
writeHeader(ds, r_schedule, 4);
writeHeader(ds, r_schedule, 5);

SharedDataStream sds(ds);

Expand All @@ -76,6 +76,7 @@ QDataStream &operator<<(QDataStream &ds, const LambdaSchedule &schedule)
<< schedule.stage_equations
<< schedule.mol_schedules
<< schedule.coupled_levers
<< schedule.stage_weights
<< static_cast<const Property &>(schedule);

return ds;
Expand All @@ -92,7 +93,7 @@ QDataStream &operator>>(QDataStream &ds, LambdaSchedule &schedule)
{
VersionID v = readHeader(ds, r_schedule);

if (v == 1 or v == 2 or v == 3 or v == 4)
if (v == 1 or v == 2 or v == 3 or v == 4 or v == 5)
{
SharedDataStream sds(ds);

Expand All @@ -116,6 +117,11 @@ QDataStream &operator>>(QDataStream &ds, LambdaSchedule &schedule)
_get_lever_name("torsion", "torsion_k");
}

if (v >= 5)
sds >> schedule.stage_weights;
else
schedule.stage_weights = QVector<double>(schedule.stage_names.count(), 1.0);

if (v < 3)
{
// need to make sure that the lever names are namespaced
Expand Down Expand Up @@ -156,7 +162,7 @@ QDataStream &operator>>(QDataStream &ds, LambdaSchedule &schedule)
}
}
else
throw version_error(v, "1, 2, 3, 4", r_schedule, CODELOC);
throw version_error(v, "1, 2, 3, 4, 5", r_schedule, CODELOC);

return ds;
}
Expand All @@ -177,7 +183,8 @@ LambdaSchedule::LambdaSchedule(const LambdaSchedule &other)
lever_names(other.lever_names), stage_names(other.stage_names),
default_equations(other.default_equations),
stage_equations(other.stage_equations),
coupled_levers(other.coupled_levers)
coupled_levers(other.coupled_levers),
stage_weights(other.stage_weights)
{
}

Expand All @@ -196,6 +203,8 @@ LambdaSchedule &LambdaSchedule::operator=(const LambdaSchedule &other)
stage_names = other.stage_names;
default_equations = other.default_equations;
stage_equations = other.stage_equations;
coupled_levers = other.coupled_levers;
stage_weights = other.stage_weights;
Property::operator=(other);
}

Expand All @@ -210,7 +219,9 @@ bool LambdaSchedule::operator==(const LambdaSchedule &other) const
lever_names == other.lever_names and
stage_names == other.stage_names and
default_equations == other.default_equations and
stage_equations == other.stage_equations;
stage_equations == other.stage_equations and
coupled_levers == other.coupled_levers and
stage_weights == other.stage_weights;
}

bool LambdaSchedule::operator!=(const LambdaSchedule &other) const
Expand Down Expand Up @@ -245,12 +256,27 @@ QString LambdaSchedule::toString() const

QStringList lines;

for (int i = 0; i < this->stage_names.count(); ++i)
bool any_non_default_weight = false;
for (const auto &w : this->stage_weights)
{
if (w != 1.0)
{
any_non_default_weight = true;
break;
}
}

lines.append(QString(" %1: %2")
.arg(this->stage_names[i])
.arg(this->default_equations[i].toOpenMMString()));
for (int i = 0; i < this->stage_names.count(); ++i)
{
if (any_non_default_weight)
lines.append(QString(" %1 (weight=%2): %3")
.arg(this->stage_names[i])
.arg(this->stage_weights[i])
.arg(this->default_equations[i].toOpenMMString()));
else
lines.append(QString(" %1: %2")
.arg(this->stage_names[i])
.arg(this->default_equations[i].toOpenMMString()));

auto keys = this->stage_equations[i].keys();
std::sort(keys.begin(), keys.end());
Expand Down Expand Up @@ -632,13 +658,27 @@ std::tuple<int, double> LambdaSchedule::resolve_lambda(double lambda_value) cons
return std::tuple<int, double>(this->nStages() - 1, 1.0);
}

double stage_width = 1.0 / this->nStages();
double total_weight = 0.0;
for (const auto &w : this->stage_weights)
total_weight += w;

double resolved = lambda_value / stage_width;
double cumulative = 0.0;
for (int i = 0; i < this->nStages(); ++i)
{
double stage_start = cumulative / total_weight;
double stage_width = this->stage_weights[i] / total_weight;
double stage_end = stage_start + stage_width;

if (lambda_value < stage_end)
{
double local_lambda = (lambda_value - stage_start) / stage_width;
return std::tuple<int, double>(i, local_lambda);
}

double stage = std::floor(resolved);
cumulative += this->stage_weights[i];
}

return std::tuple<int, double>(int(stage), resolved - stage);
return std::tuple<int, double>(this->nStages() - 1, 1.0);
}

/** Return the name of the stage that controls the forcefield parameters
Expand Down Expand Up @@ -670,16 +710,17 @@ void LambdaSchedule::clear()
this->stage_names.clear();
this->stage_equations.clear();
this->default_equations.clear();
this->stage_weights.clear();
this->constant_values = Values();
}

/** Append a morph stage onto this schedule. The morph stage is a
* standard stage that scales each forcefield parameter by
* (1-:lambda:).initial + :lambda:.final
*/
void LambdaSchedule::addMorphStage(const QString &name)
void LambdaSchedule::addMorphStage(const QString &name, double weight)
{
this->addStage(name, default_morph_equation);
this->addStage(name, default_morph_equation, weight);
}

/** Append a morph stage onto this schedule. The morph stage is a
Expand All @@ -704,9 +745,10 @@ void LambdaSchedule::addDecoupleStage(bool perturbed_is_decoupled)
* state if `perturbed_is_decoupled` is true, otherwise the
* reference state is decoupled.
*/
void LambdaSchedule::addDecoupleStage(const QString &name, bool perturbed_is_decoupled)
void LambdaSchedule::addDecoupleStage(const QString &name, bool perturbed_is_decoupled,
double weight)
{
this->addStage(name, default_morph_equation);
this->addStage(name, default_morph_equation, weight);

// we now need to ensure that the ghost/ghost and ghost-14 parameters are
// not perturbed
Expand Down Expand Up @@ -743,9 +785,10 @@ void LambdaSchedule::addAnnihilateStage(bool perturbed_is_annihilated)
* state if `perturbed_is_annihilated` is true, otherwise the
* reference state is annihilated.
*/
void LambdaSchedule::addAnnihilateStage(const QString &name, bool perturbed_is_annihilated)
void LambdaSchedule::addAnnihilateStage(const QString &name, bool perturbed_is_annihilated,
double weight)
{
this->addStage(name, default_morph_equation);
this->addStage(name, default_morph_equation, weight);
}

/** Sandwich the current set of stages with a charge-descaling and
Expand Down Expand Up @@ -796,21 +839,28 @@ void LambdaSchedule::addChargeScaleStages(double scale)
* a custom lever for this stage.
*/
void LambdaSchedule::prependStage(const QString &name,
const SireCAS::Expression &equation)
const SireCAS::Expression &equation,
double weight)
{
if (name == "*")
throw SireError::invalid_key(QObject::tr(
"The stage name '*' is reserved and cannot be used."),
CODELOC);

if (weight <= 0.0)
throw SireError::invalid_arg(QObject::tr(
"The stage weight must be positive. Got %1.")
.arg(weight),
CODELOC);

auto e = equation;

if (e == default_morph_equation)
e = default_morph_equation;

if (this->nStages() == 0)
{
this->appendStage(name, e);
this->appendStage(name, e, weight);
return;
}

Expand All @@ -823,6 +873,7 @@ void LambdaSchedule::prependStage(const QString &name,
this->stage_names.prepend(name);
this->default_equations.prepend(e);
this->stage_equations.prepend(QHash<QString, Expression>());
this->stage_weights.prepend(weight);
}

/** Append a stage called 'name' which uses the passed 'equation'
Expand All @@ -831,7 +882,8 @@ void LambdaSchedule::prependStage(const QString &name,
* a custom lever for this stage.
*/
void LambdaSchedule::appendStage(const QString &name,
const SireCAS::Expression &equation)
const SireCAS::Expression &equation,
double weight)
{
if (name == "*")
throw SireError::invalid_key(QObject::tr(
Expand All @@ -844,6 +896,12 @@ void LambdaSchedule::appendStage(const QString &name,
.arg(name),
CODELOC);

if (weight <= 0.0)
throw SireError::invalid_arg(QObject::tr(
"The stage weight must be positive. Got %1.")
.arg(weight),
CODELOC);

auto e = equation;

if (e == default_morph_equation)
Expand All @@ -852,6 +910,7 @@ void LambdaSchedule::appendStage(const QString &name,
this->stage_names.append(name);
this->default_equations.append(e);
this->stage_equations.append(QHash<QString, Expression>());
this->stage_weights.append(weight);
}

/** Insert a stage called 'name' at position `i` which uses the passed
Expand All @@ -861,26 +920,33 @@ void LambdaSchedule::appendStage(const QString &name,
*/
void LambdaSchedule::insertStage(int i,
const QString &name,
const SireCAS::Expression &equation)
const SireCAS::Expression &equation,
double weight)
{
if (name == "*")
throw SireError::invalid_key(QObject::tr(
"The stage name '*' is reserved and cannot be used."),
CODELOC);

if (weight <= 0.0)
throw SireError::invalid_arg(QObject::tr(
"The stage weight must be positive. Got %1.")
.arg(weight),
CODELOC);

auto e = equation;

if (e == default_morph_equation)
e = default_morph_equation;

if (i == 0)
{
this->prependStage(name, e);
this->prependStage(name, e, weight);
return;
}
else if (i >= this->nStages())
{
this->appendStage(name, e);
this->appendStage(name, e, weight);
return;
}

Expand All @@ -893,6 +959,7 @@ void LambdaSchedule::insertStage(int i,
this->stage_names.insert(i, name);
this->default_equations.insert(i, e);
this->stage_equations.insert(i, QHash<QString, Expression>());
this->stage_weights.insert(i, weight);
}

/** Remove the stage 'stage' */
Expand All @@ -906,6 +973,7 @@ void LambdaSchedule::removeStage(const QString &stage)
this->stage_names.removeAt(idx);
this->default_equations.removeAt(idx);
this->stage_equations.removeAt(idx);
this->stage_weights.removeAt(idx);
}

/** Append a stage called 'name' which uses the passed 'equation'
Expand All @@ -914,14 +982,15 @@ void LambdaSchedule::removeStage(const QString &stage)
* a custom lever for this stage.
*/
void LambdaSchedule::addStage(const QString &name,
const Expression &equation)
const Expression &equation,
double weight)
{
if (name == "*")
throw SireError::invalid_key(QObject::tr(
"The stage name '*' is reserved and cannot be used."),
CODELOC);

this->appendStage(name, equation);
this->appendStage(name, equation, weight);
}

/** Find the index of the stage called 'stage'. This returns
Expand All @@ -942,6 +1011,33 @@ int LambdaSchedule::find_stage(const QString &stage) const
return idx;
}

/** Set the relative weight of the stage 'stage' in lambda space.
* A stage with weight 2 occupies twice the lambda range as a
* stage with weight 1. Weights must be positive.
*/
void LambdaSchedule::setStageWeight(const QString &stage, double weight)
{
if (weight <= 0.0)
throw SireError::invalid_arg(QObject::tr(
"The stage weight must be positive. Got %1.")
.arg(weight),
CODELOC);

this->stage_weights[this->find_stage(stage)] = weight;
}

/** Return the relative weight of the stage 'stage' in lambda space. */
double LambdaSchedule::getStageWeight(const QString &stage) const
{
return this->stage_weights[this->find_stage(stage)];
}

/** Return the relative weights of all stages, in stage order. */
QVector<double> LambdaSchedule::getStageWeights() const
{
return this->stage_weights;
}

/** Set the default equation used to control levers for the
* stage 'stage' to 'equation'. This equation will be used
* to control any levers in this stage that don't have
Expand Down Expand Up @@ -1032,8 +1128,8 @@ void LambdaSchedule::removeEquation(const QString &stage,
* in sync.
*/
void LambdaSchedule::coupleLever(const QString &force, const QString &lever,
const QString &fallback_force,
const QString &fallback_lever)
const QString &fallback_force,
const QString &fallback_lever)
{
coupled_levers[_get_lever_name(force, lever)] =
_get_lever_name(fallback_force, fallback_lever);
Expand Down
Loading
Loading