Skip to content

Updated python CARFAC to receive input in Pascals#34

Open
JasonMH17 wants to merge 5 commits intogoogle:masterfrom
JasonMH17:dbspl_scaling
Open

Updated python CARFAC to receive input in Pascals#34
JasonMH17 wants to merge 5 commits intogoogle:masterfrom
JasonMH17:dbspl_scaling

Conversation

@JasonMH17
Copy link
Contributor

This update adds scaling of conventional audio input in Pascals to the CARFAC-specific scaling i.e. from 94 to 107 dB SPL for stimuli with RMS equal to 1.

Additional changes are made to the test stimuli so that they are scaled in Pascals.

@JasonMH17
Copy link
Contributor Author

@dicklyon This now incorporates changes so that input scaling can be specified in parameters but defaults to input in Pascals.

Copy link
Collaborator

@dicklyon dicklyon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments on what I think would make the code more clear for future readers, etc.

class EarHypers:
"""Hyperparameters (tagged as static in `jax.jit`) of 1 ear."""

input_scale: int
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this an int? Float/double would make more sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been replaced by float

"""All the parameters set manually for designing CARFAC."""

fs: float = 22050.0
input_scale: int = 94
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the hint that it's in dB SPL. Better hint would be in the name, eg "input_scale_dbspl".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_scale -> input_scale_dbspl throughout Numpy and JAX carfac.py

)

def __init__(self, fs=22050.0, n_ears=1, use_delay_buffer=False):
def __init__(self, fs=22050.0, input_scale=94, n_ears=1, use_delay_buffer=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how this works in Python, but if there's a positional parameter option, and there might be calls to this with n_ears in that position, then this would break them, no? Better to put the new parameter last?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't implemented as of yet as typically whenever there are changes from the default code, I have seen the keyword being used. equally I have made input_scale_dbspl the second parameter in CarfacDesignParameters because, like fs, it is independent of n_ears and therefore sits outside the vast majority of params in "ear_params". Happy to change if you think this might still be an issue.


Args:
fs: Samples per second.
input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still doesn't say that the parameter is in dB SPL with default 94. And what does "CARFAC is considered as 107 dB SPL @ RMS=1" mean. Probably it means "while the original v1 and v2 CARFAC used an input scale of 107 dB SPL (for RMS = 1)".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as suggested


Args:
input_waves: the audio input.
input_waves: the audio input in Pascals.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"in Pascals" seems wrong. Maybe "in Pascals with default input_scale". Or is this now after the scaling has been accounted for. It's hard for me to follow in the diffs, which don't show me enough context, and if I try to see more context the file way too big and I get lost. I'm not very good at this github thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as suggested...as you indicate the current changed means that unless explicitly specified with "input_scale_dbspl" the expected input audio should be in pascals.

n_ears = 1
random_generator = jax.random.PRNGKey(random_seed)
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears))
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good thing you're in Python 3. Many languages, including C and Python 2, would evaluate (107-94)/20 as 13/20 -> 0. Include a decimate point or several to make the point more clear, in case someone ports it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually removed this as this existed prior to "input_scale_dbspl" parameter...and it is probably best to use that instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially to confirm that calling this parameter does what it says on the tin.

n_samp = 200
random_generator = jax.random.PRNGKey(random_seed)
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears))
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to run this test at 107 dB SPL? Just for compatibility with old test? Better say so if that's it. Did we run all the various tests at this high level? Should we fix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests date from CARFAC v2 therefore all inputs were at native CARFAC scaling of 107 dB SPL for RMS =1. As noted above, I have changed "input_scale_dbspl" to account for the changes to carfac.py so that the test data matches golden data. As for the absolute level, I have left it for what I presume is RMS=1. Happy to bring the level down and generate new golden data if desired!!

the input_waves are assumed to be sampled at the same rate as the
CARFAC is designed for; a resampling may be needed before calling this.

input_waves are considered as being in Pascals and therefore should equal
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of "should equal", say "their level is"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed as suggested


Args:
fs: Samples per second.
input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct SI unit name is "pascals" with lowercase P. No SI units are capitalized, except in abbreviations, e.g. P.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as indicated


fs = 22050.0
t = np.arange(0, 1, 1 / fs) # A second of noise.
amplitude = 1e-3 # -70 dBFS, around 30 or 40 dB SPL
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to stay compatible by changing the test waveform amplitude here, rather than first assuming 107 and then converting to 94? Or leave a hint about how we got here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we now use the CARFAC v1/v2 input scaling, this can remain the same I guess?!

…ffectively. Also reverted to using 107 dB SPL scaling for carfac_test.py by implementing input_scale_dbspl parameter in carfac_design function
…ectively. As in Numpy carfac_test.py, input scaling is reverted to native CARFAC v1/v2 107 dB SPL. Title added for Numpy/JAX carfac_test.py files to more easily differentiate
@JasonMH17
Copy link
Contributor Author

@dicklyon Thanks for the comments...I have made some changes as suggested and also pivoted to testing the input_scale_dbspl parameter within CARFACDesignParameters in the test functions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants