# Copyright 2023-2024 Blue Brain Project / EPFL
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the expected data structures associated with the stimulus defined in
simulation configs.
Run-time validates the data via Pydantic.
"""
from __future__ import annotations
from dataclasses import field
from enum import Enum
from typing import Optional
import warnings
from pydantic import field_validator, NonNegativeFloat, PositiveFloat
from pydantic.dataclasses import dataclass
class ClampMode(Enum):
    """Current clamp or conductance (dynamic) clamp.

    Member values are the SONATA ``input_type`` spellings parsed by
    ``Stimulus.from_sonata`` (lower-cased, default ``"current_clamp"``).
    BlueConfig ``Mode`` values (``Current`` / ``Conductance``) are mapped
    onto the same members by ``Stimulus.from_blueconfig``.
    """

    CURRENT = "current_clamp"
    CONDUCTANCE = "conductance"
class Pattern(Enum):
    """Stimulus pattern; member values are the SONATA ``module`` names.

    ``from_sonata`` parses SONATA spellings and ``from_blueconfig`` parses
    the (CamelCase) BlueConfig spellings into the same members.
    """

    NOISE = "noise"
    HYPERPOLARIZING = "hyperpolarizing"
    PULSE = "pulse"
    LINEAR = "linear"
    RELATIVE_LINEAR = "relative_linear"
    SYNAPSE_REPLAY = "synapse_replay"
    SHOT_NOISE = "shot_noise"
    RELATIVE_SHOT_NOISE = "relative_shot_noise"
    ORNSTEIN_UHLENBECK = "ornstein_uhlenbeck"
    RELATIVE_ORNSTEIN_UHLENBECK = "relative_ornstein_uhlenbeck"
    SINUSOIDAL = "sinusoidal"
    SECLAMP = "seclamp"
    SUBTHRESHOLD = "subthreshold"

    @classmethod
    def from_blueconfig(cls, pattern: str) -> Pattern:
        """Map a BlueConfig ``Pattern`` name (e.g. ``"RelativeLinear"``) to a member.

        LINEAR, SINUSOIDAL and SECLAMP have no BlueConfig spelling here.

        Raises:
            ValueError: if ``pattern`` is not a supported BlueConfig name.
        """
        # Kept local: a dict assigned in the Enum body would become a member.
        blueconfig_names = {
            "Noise": cls.NOISE,
            "Hyperpolarizing": cls.HYPERPOLARIZING,
            "Pulse": cls.PULSE,
            "RelativeLinear": cls.RELATIVE_LINEAR,
            "SynapseReplay": cls.SYNAPSE_REPLAY,
            "ShotNoise": cls.SHOT_NOISE,
            "RelativeShotNoise": cls.RELATIVE_SHOT_NOISE,
            "OrnsteinUhlenbeck": cls.ORNSTEIN_UHLENBECK,
            "RelativeOrnsteinUhlenbeck": cls.RELATIVE_ORNSTEIN_UHLENBECK,
            "SubThreshold": cls.SUBTHRESHOLD,
        }
        member = blueconfig_names.get(pattern)
        if member is None:
            raise ValueError(f"Unknown pattern {pattern}")
        return member

    @classmethod
    def from_sonata(cls, pattern: str) -> Pattern:
        """Map a SONATA ``module`` name (e.g. ``"shot_noise"``) to a member.

        Every member value is a valid SONATA name, so this is a value lookup
        that only matches by string equality (a ``Pattern`` member passed in
        is rejected, as are other spellings).

        Raises:
            ValueError: if ``pattern`` is not a member value.
        """
        for member in cls:
            if member.value == pattern:
                return member
        raise ValueError(f"Unknown pattern {pattern}")
@dataclass(frozen=True, config=dict(extra="forbid"))
class Stimulus:
    """Base class of the stimuli parsed from BlueConfig or SONATA configs.

    Attributes:
        target: Name of what the stimulus is applied to. For SONATA entries
            this is the ``compartment_set`` name when one is given, otherwise
            the ``node_set`` name; for BlueConfig entries it is ``Target``.
        delay: ``Delay`` / ``delay`` value of the entry; validated as >= 0.
        duration: ``Duration`` / ``duration`` value; validated as >= 0.
        node_set: Node set name (keyword-only). BlueConfig entries reuse
            ``Target`` here.
        compartment_set: Compartment set name (keyword-only). ``from_sonata``
            never sets both this and ``node_set``.
    """

    target: str
    delay: NonNegativeFloat
    duration: NonNegativeFloat
    node_set: Optional[str] = field(default=None, kw_only=True)
    compartment_set: Optional[str] = field(default=None, kw_only=True)

    @staticmethod
    def _blueconfig_clamp_mode(stimulus_entry: dict) -> ClampMode:
        """Parse the optional, case-insensitive BlueConfig ``Mode`` key (default ``Current``)."""
        mode_str = stimulus_entry.get("Mode", "Current").lower()
        if mode_str == "current":
            return ClampMode.CURRENT
        if mode_str == "conductance":
            return ClampMode.CONDUCTANCE
        raise ValueError(f"Unknown clamp mode {mode_str}")

    @staticmethod
    def _blueconfig_common_kwargs(stimulus_entry: dict) -> dict:
        """Collect the constructor kwargs shared by every BlueConfig stimulus."""
        target = stimulus_entry["Target"]
        return {
            "target": target,
            "delay": stimulus_entry["Delay"],
            "duration": stimulus_entry["Duration"],
            # The BlueConfig ``Target`` doubles as the node set; no compartment
            # set is read from BlueConfig entries.
            "node_set": target,
            "compartment_set": None,
        }

    @staticmethod
    def _sonata_clamp_mode(stimulus_entry: dict) -> ClampMode:
        """Parse the optional SONATA ``input_type`` key (default ``current_clamp``)."""
        return ClampMode(stimulus_entry.get("input_type", "current_clamp").lower())

    @classmethod
    def from_blueconfig(cls, stimulus_entry: dict) -> Optional[Stimulus]:
        """Build a stimulus from a BlueConfig stimulus entry.

        Args:
            stimulus_entry: Mapping with ``Pattern``, ``Target``, ``Delay``,
                ``Duration`` and the pattern-specific keys.

        Returns:
            The matching ``Stimulus`` subclass instance, or ``None`` for
            ``SynapseReplay`` entries, which are skipped with a warning.

        Raises:
            ValueError: for an unknown ``Pattern`` or ``Mode`` value.
            KeyError: when a required key is missing.
        """
        pattern = Pattern.from_blueconfig(stimulus_entry["Pattern"])
        # Parsed before dispatch so an invalid ``Mode`` is rejected for every
        # pattern, even those that do not use it.
        mode = cls._blueconfig_clamp_mode(stimulus_entry)

        if pattern == Pattern.SYNAPSE_REPLAY:
            warnings.warn("Ignoring synapse replay stimulus as it is not supported")
            return None

        common = cls._blueconfig_common_kwargs(stimulus_entry)
        if pattern == Pattern.NOISE:
            return Noise(
                **common,
                mean_percent=stimulus_entry["MeanPercent"],
                variance=stimulus_entry["Variance"],
            )
        if pattern == Pattern.HYPERPOLARIZING:
            return Hyperpolarizing(**common)
        if pattern == Pattern.PULSE:
            return Pulse(
                **common,
                amp_start=stimulus_entry["AmpStart"],
                width=stimulus_entry["Width"],
                frequency=stimulus_entry["Frequency"],
            )
        if pattern == Pattern.RELATIVE_LINEAR:
            percent_start = stimulus_entry["PercentStart"]
            return RelativeLinear(
                **common,
                percent_start=percent_start,
                # Optional; defaults to PercentStart, matching from_sonata.
                percent_end=stimulus_entry.get("PercentEnd", percent_start),
            )
        if pattern == Pattern.SHOT_NOISE:
            return ShotNoise(
                **common,
                dt=stimulus_entry.get("Dt", 0.25),
                rise_time=stimulus_entry["RiseTime"],
                decay_time=stimulus_entry["DecayTime"],
                rate=stimulus_entry["Rate"],
                amp_mean=stimulus_entry["AmpMean"],
                amp_var=stimulus_entry["AmpVar"],
                seed=stimulus_entry.get("Seed", None),
                mode=mode,
                reversal=stimulus_entry.get("Reversal", 0.0),
            )
        if pattern == Pattern.RELATIVE_SHOT_NOISE:
            return RelativeShotNoise(
                **common,
                dt=stimulus_entry.get("Dt", 0.25),
                rise_time=stimulus_entry["RiseTime"],
                decay_time=stimulus_entry["DecayTime"],
                mean_percent=stimulus_entry["MeanPercent"],
                sd_percent=stimulus_entry["SDPercent"],
                relative_skew=stimulus_entry.get("RelativeSkew", 0.5),
                seed=stimulus_entry.get("Seed", None),
                mode=mode,
                reversal=stimulus_entry.get("Reversal", 0.0),
            )
        if pattern == Pattern.ORNSTEIN_UHLENBECK:
            return OrnsteinUhlenbeck(
                **common,
                dt=stimulus_entry.get("Dt", 0.25),
                tau=stimulus_entry["Tau"],
                sigma=stimulus_entry["Sigma"],
                mean=stimulus_entry["Mean"],
                seed=stimulus_entry.get("Seed", None),
                mode=mode,
                reversal=stimulus_entry.get("Reversal", 0.0),
            )
        if pattern == Pattern.RELATIVE_ORNSTEIN_UHLENBECK:
            return RelativeOrnsteinUhlenbeck(
                **common,
                dt=stimulus_entry.get("Dt", 0.25),
                tau=stimulus_entry["Tau"],
                mean_percent=stimulus_entry["MeanPercent"],
                sd_percent=stimulus_entry["SDPercent"],
                seed=stimulus_entry.get("Seed", None),
                mode=mode,
                reversal=stimulus_entry.get("Reversal", 0.0),
            )
        if pattern == Pattern.SUBTHRESHOLD:
            return SubThreshold(
                **common,
                percent_less=stimulus_entry["PercentLess"],
            )
        raise ValueError(f"Unknown pattern {pattern}")

    @classmethod
    def from_sonata(cls, stimulus_entry: dict, config_dir: Optional[str] = None) -> Optional[Stimulus]:
        """Build a stimulus from a SONATA ``inputs`` entry.

        Args:
            stimulus_entry: Mapping with ``module``, ``delay``, ``duration``,
                exactly one of ``node_set`` / ``compartment_set``, and the
                module-specific keys.
            config_dir: Passed through to ``SynapseReplay`` only.

        Returns:
            The matching ``Stimulus`` subclass instance.

        Raises:
            ValueError: for an unknown ``module``; when both or neither of
                ``node_set`` / ``compartment_set`` are given; when a noise
                entry does not have exactly one of ``mean`` / ``mean_percent``;
                for an ``input_type`` that is not a ``ClampMode`` value.
            KeyError: when a required key is missing.
        """
        pattern = Pattern.from_sonata(stimulus_entry["module"])
        node_set = stimulus_entry.get("node_set")
        compartment_set = stimulus_entry.get("compartment_set")
        if node_set is not None and compartment_set is not None:
            raise ValueError("Stimulus entry must not contain both 'node_set' and 'compartment_set'.")
        target_name: str | None = compartment_set if compartment_set is not None else node_set
        if target_name is None:
            raise ValueError("Stimulus entry must contain either 'node_set' or 'compartment_set'.")

        # Checked before any required key is read, so this error takes
        # precedence over a KeyError for e.g. a missing 'delay'.
        if pattern == Pattern.NOISE and ("mean" in stimulus_entry) == ("mean_percent" in stimulus_entry):
            raise ValueError("Noise input must contain exactly one of 'mean' or 'mean_percent'.")

        common = {
            "target": target_name,
            "delay": stimulus_entry["delay"],
            "duration": stimulus_entry["duration"],
            "node_set": node_set,
            "compartment_set": compartment_set,
        }
        if pattern == Pattern.NOISE:
            return Noise(
                **common,
                mean=stimulus_entry.get("mean"),
                mean_percent=stimulus_entry.get("mean_percent"),
                variance=stimulus_entry["variance"],
            )
        if pattern == Pattern.HYPERPOLARIZING:
            return Hyperpolarizing(**common)
        if pattern == Pattern.PULSE:
            return Pulse(
                **common,
                amp_start=stimulus_entry["amp_start"],
                width=stimulus_entry["width"],
                frequency=stimulus_entry["frequency"],
            )
        if pattern == Pattern.LINEAR:
            amp_start = stimulus_entry["amp_start"]
            return Linear(
                **common,
                amp_start=amp_start,
                amp_end=stimulus_entry.get("amp_end", amp_start),
            )
        if pattern == Pattern.RELATIVE_LINEAR:
            percent_start = stimulus_entry["percent_start"]
            return RelativeLinear(
                **common,
                percent_start=percent_start,
                percent_end=stimulus_entry.get("percent_end", percent_start),
            )
        if pattern == Pattern.SYNAPSE_REPLAY:
            return SynapseReplay(
                **common,
                spike_file=stimulus_entry["spike_file"],
                config_dir=config_dir,
            )
        if pattern == Pattern.SHOT_NOISE:
            return ShotNoise(
                **common,
                dt=stimulus_entry.get("dt", 0.25),
                rise_time=stimulus_entry["rise_time"],
                decay_time=stimulus_entry["decay_time"],
                rate=stimulus_entry["rate"],
                amp_mean=stimulus_entry["amp_mean"],
                amp_var=stimulus_entry["amp_var"],
                seed=stimulus_entry.get("random_seed", None),
                mode=cls._sonata_clamp_mode(stimulus_entry),
                reversal=stimulus_entry.get("reversal", 0.0),
            )
        if pattern == Pattern.RELATIVE_SHOT_NOISE:
            return RelativeShotNoise(
                **common,
                dt=stimulus_entry.get("dt", 0.25),
                rise_time=stimulus_entry["rise_time"],
                decay_time=stimulus_entry["decay_time"],
                mean_percent=stimulus_entry["mean_percent"],
                sd_percent=stimulus_entry["sd_percent"],
                relative_skew=stimulus_entry.get("relative_skew", 0.5),
                seed=stimulus_entry.get("random_seed", None),
                mode=cls._sonata_clamp_mode(stimulus_entry),
                reversal=stimulus_entry.get("reversal", 0.0),
            )
        if pattern == Pattern.ORNSTEIN_UHLENBECK:
            return OrnsteinUhlenbeck(
                **common,
                dt=stimulus_entry.get("dt", 0.25),
                tau=stimulus_entry["tau"],
                sigma=stimulus_entry["sigma"],
                mean=stimulus_entry["mean"],
                seed=stimulus_entry.get("random_seed", None),
                mode=cls._sonata_clamp_mode(stimulus_entry),
                reversal=stimulus_entry.get("reversal", 0.0),
            )
        if pattern == Pattern.RELATIVE_ORNSTEIN_UHLENBECK:
            return RelativeOrnsteinUhlenbeck(
                **common,
                dt=stimulus_entry.get("dt", 0.25),
                tau=stimulus_entry["tau"],
                mean_percent=stimulus_entry["mean_percent"],
                sd_percent=stimulus_entry["sd_percent"],
                seed=stimulus_entry.get("random_seed", None),
                mode=cls._sonata_clamp_mode(stimulus_entry),
                reversal=stimulus_entry.get("reversal", 0.0),
            )
        if pattern == Pattern.SINUSOIDAL:
            return Sinusoidal(
                **common,
                amp_start=stimulus_entry["amp_start"],
                frequency=stimulus_entry["frequency"],
            )
        if pattern == Pattern.SECLAMP:
            return SEClamp(
                **common,
                voltage=stimulus_entry["voltage"],
                durations=stimulus_entry.get("duration_levels", None),
                voltages=stimulus_entry.get("voltage_levels", None),
                series_resistance=stimulus_entry.get("series_resistance", 0.01),
            )
        if pattern == Pattern.SUBTHRESHOLD:
            return SubThreshold(
                **common,
                percent_less=stimulus_entry["percent_less"],
            )
        raise ValueError(f"Unknown pattern {pattern}")
@dataclass(frozen=True, config=dict(extra="forbid"))
class Noise(Stimulus):
    """Noise stimulus given either as an absolute mean or relative to threshold.

    Invariants checked in ``__post_init__``: exactly one of ``mean`` and
    ``mean_percent`` is set, and ``variance`` is not negative.
    """

    variance: float
    mean: Optional[float] = None  # nA
    mean_percent: Optional[float] = None  # % of threshold

    def __post_init__(self):
        provided = [value for value in (self.mean, self.mean_percent) if value is not None]
        if len(provided) != 1:
            raise ValueError("Noise stimulus must define exactly one of 'mean' or 'mean_percent'.")
        if self.variance < 0:
            raise ValueError("'variance' must be >= 0.")
@dataclass(frozen=True, config=dict(extra="forbid"))
class Hyperpolarizing(Stimulus):
    """Hyperpolarizing stimulus (SONATA ``hyperpolarizing`` / BlueConfig ``Hyperpolarizing``).

    Adds no fields beyond those inherited from ``Stimulus``.
    """

    ...
@dataclass(frozen=True, config=dict(extra="forbid"))
class Pulse(Stimulus):
    """Pulse stimulus (SONATA ``pulse`` / BlueConfig ``Pulse``)."""

    amp_start: float  # from ``amp_start`` / ``AmpStart``
    width: float  # from ``width`` / ``Width``
    frequency: float  # from ``frequency`` / ``Frequency``
@dataclass(frozen=True, config=dict(extra="forbid"))
class Linear(Stimulus):
    """Linear stimulus (SONATA module ``linear``).

    Only built by ``Stimulus.from_sonata``; ``Pattern.from_blueconfig`` has
    no spelling for this pattern.
    """

    amp_start: float  # from ``amp_start``
    amp_end: float  # from ``amp_end``; from_sonata defaults it to amp_start
@dataclass(frozen=True, config=dict(extra="forbid"))
class RelativeLinear(Stimulus):
    """Linear stimulus expressed in percent (SONATA ``relative_linear`` / BlueConfig ``RelativeLinear``)."""

    percent_start: float  # from ``percent_start`` / ``PercentStart``
    percent_end: float  # from ``percent_end`` / ``PercentEnd``; from_sonata defaults it to percent_start
@dataclass(frozen=True, config=dict(extra="forbid"))
class SynapseReplay(Stimulus):
    """Synapse replay stimulus (SONATA module ``synapse_replay``).

    Only built by ``Stimulus.from_sonata``; BlueConfig ``SynapseReplay``
    entries are skipped with a warning instead.
    """

    spike_file: str  # from the ``spike_file`` key
    # The config_dir argument of from_sonata; presumably used to resolve a
    # relative spike_file path — NOTE(review): confirm with the consumer.
    config_dir: Optional[str] = None
@dataclass(frozen=True, config=dict(extra="forbid"))
class ShotNoise(Stimulus):
    """Shot noise stimulus (SONATA ``shot_noise`` / BlueConfig ``ShotNoise``).

    ``decay_time`` must be strictly greater than ``rise_time``.
    """

    rise_time: float
    decay_time: float
    rate: float
    amp_mean: float
    amp_var: float
    dt: float = 0.25
    seed: Optional[int] = None  # SONATA ``random_seed`` / BlueConfig ``Seed``
    mode: ClampMode = ClampMode.CURRENT  # SONATA ``input_type`` / BlueConfig ``Mode``
    reversal: float = 0.0

    @field_validator("decay_time")
    @classmethod
    def decay_time_gt_rise_time(cls, v, values):
        """Reject a ``decay_time`` that does not exceed ``rise_time``.

        ``values.data`` only holds fields that validated successfully. If
        ``rise_time`` was itself invalid it is absent, and the check is skipped
        so the rise_time error is reported instead of a bare KeyError.
        """
        rise_time = values.data.get("rise_time")
        if rise_time is not None and v <= rise_time:
            raise ValueError("decay_time must be greater than rise_time")
        return v
@dataclass(frozen=True, config=dict(extra="forbid"))
class RelativeShotNoise(Stimulus):
    """Shot noise given in percent (SONATA ``relative_shot_noise`` / BlueConfig ``RelativeShotNoise``).

    ``decay_time`` must be strictly greater than ``rise_time`` and
    ``relative_skew`` must lie in [0, 1].
    """

    rise_time: float
    decay_time: float
    mean_percent: float
    sd_percent: float
    relative_skew: float = 0.5
    dt: float = 0.25
    seed: Optional[int] = None  # SONATA ``random_seed`` / BlueConfig ``Seed``
    mode: ClampMode = ClampMode.CURRENT  # SONATA ``input_type`` / BlueConfig ``Mode``
    reversal: float = 0.0

    @field_validator("decay_time")
    @classmethod
    def decay_time_gt_rise_time(cls, v, values):
        """Reject a ``decay_time`` that does not exceed ``rise_time``.

        ``rise_time`` is absent from ``values.data`` when it failed its own
        validation; the check is then skipped so that error is reported
        instead of a bare KeyError.
        """
        rise_time = values.data.get("rise_time")
        if rise_time is not None and v <= rise_time:
            raise ValueError("decay_time must be greater than rise_time")
        return v

    @field_validator("relative_skew")
    @classmethod
    def relative_skew_in_range(cls, v):
        """Reject a ``relative_skew`` outside [0, 1]."""
        if v < 0.0 or v > 1.0:
            raise ValueError("relative skewness must be in [0,1]")
        return v
@dataclass(frozen=True, config=dict(extra="forbid"))
class OrnsteinUhlenbeck(Stimulus):
    """Ornstein-Uhlenbeck stimulus (SONATA ``ornstein_uhlenbeck`` / BlueConfig ``OrnsteinUhlenbeck``).

    ``sigma`` must be positive. A ``mean`` below ``-2 * sigma`` is accepted
    but triggers a warning.
    """

    tau: float
    sigma: PositiveFloat
    mean: float
    dt: float = 0.25
    seed: Optional[int] = None  # SONATA ``random_seed`` / BlueConfig ``Seed``
    mode: ClampMode = ClampMode.CURRENT  # SONATA ``input_type`` / BlueConfig ``Mode``
    reversal: float = 0.0

    @field_validator("mean")
    @classmethod
    def mean_in_range(cls, v, values):
        """Warn (without rejecting) when ``mean`` is below ``-2 * sigma``.

        ``sigma`` is absent from ``values.data`` when it failed its own
        validation (e.g. not positive); the check is then skipped so that
        error is reported instead of a bare KeyError.
        """
        sigma = values.data.get("sigma")
        if sigma is not None and v < 0 and abs(v) > 2 * sigma:
            # One concatenated message: a second positional argument would be
            # taken as the warning category and raise TypeError.
            warnings.warn(
                "mean is outside of range [0, 2*sigma],"
                " ornstein uhlenbeck signal is mostly zero."
            )
        return v
@dataclass(frozen=True, config=dict(extra="forbid"))
class RelativeOrnsteinUhlenbeck(Stimulus):
    """Ornstein-Uhlenbeck stimulus given in percent.

    Built from SONATA ``relative_ornstein_uhlenbeck`` or BlueConfig
    ``RelativeOrnsteinUhlenbeck`` entries. No extra validation is applied.
    """

    tau: float  # from ``tau`` / ``Tau``
    mean_percent: float  # from ``mean_percent`` / ``MeanPercent``
    sd_percent: float  # from ``sd_percent`` / ``SDPercent``
    dt: float = 0.25
    seed: Optional[int] = None  # SONATA ``random_seed`` / BlueConfig ``Seed``
    mode: ClampMode = ClampMode.CURRENT  # SONATA ``input_type`` / BlueConfig ``Mode``
    reversal: float = 0.0
@dataclass(frozen=True, config=dict(extra="forbid"))
class Sinusoidal(Stimulus):
    """Sinusoidal stimulus (SONATA module ``sinusoidal``); only built by ``Stimulus.from_sonata``."""

    amp_start: float  # from ``amp_start``
    frequency: float  # from ``frequency``
@dataclass(frozen=True, config=dict(extra="forbid"))
class SEClamp(Stimulus):
    """SEClamp stimulus (SONATA module ``seclamp``); only built by ``Stimulus.from_sonata``.

    All fields are required here. ``from_sonata`` supplies None for missing
    level lists and 0.01 for a missing ``series_resistance``.
    """

    voltage: float  # from ``voltage``
    durations: Optional[list[float]]  # from ``duration_levels``, or None
    voltages: Optional[list[float]]  # from ``voltage_levels``, or None
    series_resistance: float  # from ``series_resistance`` (from_sonata default 0.01)
@dataclass(frozen=True, config=dict(extra="forbid"))
class SubThreshold(Stimulus):
    """Injects a current step at some percent below a cell's threshold.

    Built from SONATA ``subthreshold`` or BlueConfig ``SubThreshold`` entries.
    """

    percent_less: float  # from ``percent_less`` / ``PercentLess``