Skip to content

Commit eca8905

Browse files
authored
Merge pull request #2263 from dopplershift/moist-lapse
Refactor moist_lapse
2 parents 1153590 + 20deddb commit eca8905

5 files changed

Lines changed: 62 additions & 26 deletions

File tree

src/metpy/calc/thermo.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,9 @@ def moist_lapse(pressure, temperature, reference_pressure=None):
295295
Renamed ``ref_pressure`` parameter to ``reference_pressure``
296296
297297
"""
298-
def dt(t, p):
299-
t = units.Quantity(t, temperature.units)
300-
p = units.Quantity(p, pressure.units)
298+
def dt(p, t):
299+
t = units.Quantity(t, 'kelvin')
300+
p = units.Quantity(p, 'mbar')
301301
rs = saturation_mixing_ratio(p, t)
302302
frac = ((mpconsts.Rd * t + mpconsts.Lv * rs)
303303
/ (mpconsts.Cp_d + (mpconsts.Lv * mpconsts.Lv * rs * mpconsts.epsilon
@@ -307,39 +307,55 @@ def dt(t, p):
307307
pressure = np.atleast_1d(pressure)
308308
if reference_pressure is None:
309309
reference_pressure = pressure[0]
310-
311-
if np.isnan(reference_pressure):
312-
return units.Quantity(np.full(pressure.shape, np.nan), temperature.units)
313-
314310
pressure = pressure.to('mbar')
315311
reference_pressure = reference_pressure.to('mbar')
316312
org_units = temperature.units
317-
temperature = np.atleast_1d(temperature).to('kelvin')
313+
temperature = np.atleast_1d(temperature).m_as('kelvin')
314+
315+
if np.isnan(reference_pressure) or np.all(np.isnan(temperature)):
316+
return units.Quantity(np.full((temperature.size, pressure.size), np.nan), org_units)
318317

319318
pres_decreasing = (pressure[0] > pressure[-1])
320319
if pres_decreasing:
321320
# Everything is easier if pressures are in increasing order
322321
pressure = pressure[::-1]
323322

324-
ref_pres_idx = np.searchsorted(pressure.m, reference_pressure.m, side='right')
325-
ret_temperatures = np.empty((0, temperature.shape[0]))
323+
# It would be preferable to use a regular solver like RK45, but as of scipy 1.8.0
324+
# anything other than LSODA goes into an infinite loop when given NaNs for y0.
325+
solver_args = {'fun': dt, 'y0': temperature,
326+
'method': 'LSODA', 'atol': 1e-7, 'rtol': 1.5e-8}
326327

327-
if _greater_or_close(reference_pressure, pressure.min()):
328-
# Integrate downward in pressure
329-
pres_down = np.append(reference_pressure.m, pressure[(ref_pres_idx - 1)::-1].m)
330-
trace_down = si.odeint(dt, temperature.m.squeeze(), pres_down.squeeze())
331-
ret_temperatures = np.concatenate((ret_temperatures, trace_down[:0:-1]))
332-
333-
if reference_pressure < pressure.max():
334-
# Integrate upward in pressure
335-
pres_up = np.append(reference_pressure.m, pressure[ref_pres_idx:].m)
336-
trace_up = si.odeint(dt, temperature.m.squeeze(), pres_up.squeeze())
337-
ret_temperatures = np.concatenate((ret_temperatures, trace_up[1:]))
328+
# Need to handle close points to avoid an error in the solver
329+
close = np.isclose(pressure, reference_pressure)
330+
if np.any(close):
331+
ret = np.broadcast_to(temperature[:, np.newaxis], (temperature.size, np.sum(close)))
332+
else:
333+
ret = np.empty((temperature.size, 0), dtype=temperature.dtype)
334+
335+
# Do we have any points above the reference pressure
336+
points_above = (pressure < reference_pressure) & ~close
337+
if np.any(points_above):
338+
# Integrate upward--need to flip so values are properly ordered from ref to min
339+
press_side = pressure[points_above][::-1].m
340+
341+
# Flip on exit so t values correspond to increasing pressure
342+
trace = si.solve_ivp(t_span=(reference_pressure.m, press_side[-1]),
343+
t_eval=press_side, **solver_args).y[..., ::-1]
344+
ret = np.concatenate((trace, ret), axis=-1)
345+
346+
# Do we have any points below the reference pressure
347+
points_below = ~points_above & ~close
348+
if np.any(points_below):
349+
# Integrate downward
350+
press_side = pressure[points_below].m
351+
trace = si.solve_ivp(t_span=(reference_pressure.m, press_side[-1]),
352+
t_eval=press_side, **solver_args).y
353+
ret = np.concatenate((ret, trace), axis=-1)
338354

339355
if pres_decreasing:
340-
ret_temperatures = ret_temperatures[::-1]
356+
ret = ret[..., ::-1]
341357

342-
return units.Quantity(ret_temperatures.T.squeeze(), 'kelvin').to(org_units)
358+
return units.Quantity(ret.squeeze(), 'kelvin').to(org_units)
343359

344360

345361
@exporter.export

src/metpy/plots/skewt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,7 @@ def plot_moist_adiabats(self, t0=None, pressure=None, **kwargs):
545545
pressure = units.Quantity(np.linspace(*self.ax.get_ylim()), 'mbar')
546546

547547
# Assemble into data for plotting
548-
t = moist_lapse(pressure, t0[:, np.newaxis],
549-
units.Quantity(1000., 'mbar')).to(units.degC)
548+
t = moist_lapse(pressure, t0, units.Quantity(1000., 'mbar')).to(units.degC)
550549
linedata = [np.vstack((ti.m, pressure.m)).T for ti in t]
551550

552551
# Add to plot

tests/calc/test_thermo.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: BSD-3-Clause
44
"""Test the `thermo` module."""
55

6+
import warnings
7+
68
import numpy as np
79
import pytest
810
import xarray as xr
@@ -131,12 +133,31 @@ def test_moist_lapse_ref_pressure():
131133
assert_array_almost_equal(temp, true_temp, 2)
132134

133135

136+
def test_moist_lapse_multiple_temps():
137+
"""Test moist_lapse with multiple starting temperatures."""
138+
temp = moist_lapse(np.array([1050., 800., 600., 500., 400.]) * units.mbar,
139+
np.array([19.85, np.nan, 19.85]) * units.degC, 1000. * units.mbar)
140+
true_temp = np.array([[294.76, 284.64, 272.81, 264.42, 252.91],
141+
[np.nan, np.nan, np.nan, np.nan, np.nan],
142+
[294.76, 284.64, 272.81, 264.42, 252.91]]) * units.kelvin
143+
assert_array_almost_equal(temp, true_temp, 2)
144+
145+
134146
def test_moist_lapse_scalar():
135147
"""Test moist_lapse when given a scalar desired pressure and a reference pressure."""
136148
temp = moist_lapse(np.array([800.]) * units.mbar, 19.85 * units.degC, 1000. * units.mbar)
137149
assert_almost_equal(temp, 284.64 * units.kelvin, 2)
138150

139151

152+
def test_moist_lapse_close_start():
153+
"""Test that we behave correctly with a reference pressure close to an actual pressure."""
154+
with warnings.catch_warnings(record=True) as record:
155+
temp = moist_lapse(units.Quantity(1000, 'hPa'), 0 * units.degC,
156+
units.Quantity(1000., 'mbar'))
157+
assert len(record) == 0
158+
assert_almost_equal(temp, units.Quantity(0., 'degC'))
159+
160+
140161
def test_moist_lapse_uniform():
141162
"""Test moist_lapse when given a uniform array of pressures."""
142163
temp = moist_lapse(np.array([900., 900., 900.]) * units.hPa, 20. * units.degC)
@@ -158,7 +179,7 @@ def test_moist_lapse_nan_ref_press():
158179
def test_moist_lapse_downwards():
159180
"""Test moist_lapse when integrating downwards (#2128)."""
160181
temp = moist_lapse(units.Quantity([600, 700], 'mbar'), units.Quantity(0, 'degC'))
161-
assert_almost_equal(temp, units.Quantity([0, 6.47748353], units.degC))
182+
assert_almost_equal(temp, units.Quantity([0, 6.47748353], units.degC), 4)
162183

163184

164185
@pytest.mark.parametrize('direction', (1, -1))
-11 Bytes
Loading
0 Bytes
Loading

0 commit comments

Comments
 (0)