Source code for bluecellulab.circuit_simulation

# Copyright 2023-2024 Blue Brain Project / EPFL

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ssim class of bluecellulab that loads a circuit simulation to do cell
simulations."""

from __future__ import annotations
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
import logging
import warnings

from bluecellulab.circuit.gid_resolver import GidNamespace
from bluecellulab.reports.utils import prepare_recordings_for_reports
import neuron
import numpy as np
import pandas as pd
from pydantic.types import NonNegativeInt
from typing_extensions import deprecated

import bluecellulab
from bluecellulab.cell import CellDict
from bluecellulab.cell.sonata_proxy import SonataProxy
from bluecellulab.circuit import CellId, SimulationValidator, SynapseProperty
from bluecellulab.circuit.circuit_access import (
    CircuitAccess,
    BluepyCircuitAccess,
    SonataCircuitAccess,
    get_synapse_connection_parameters,
)
from bluecellulab.circuit.config import SimulationConfig
from bluecellulab.circuit.format import determine_circuit_format, CircuitFormat
from bluecellulab.circuit.node_id import create_cell_id, create_cell_ids
from bluecellulab.circuit.simulation_access import (
    BluepySimulationAccess,
    SimulationAccess,
    SonataSimulationAccess,
    _sample_array,
)
from bluecellulab.importer import load_mod_files
from bluecellulab.rngsettings import RNGSettings
from bluecellulab.simulation.neuron_globals import NeuronGlobals
from bluecellulab.stimulus.circuit_stimulus_definitions import (
    Noise,
    OrnsteinUhlenbeck,
    RelativeOrnsteinUhlenbeck,
    RelativeShotNoise,
    ShotNoise,
)
import bluecellulab.stimulus.circuit_stimulus_definitions as circuit_stimulus_definitions
from bluecellulab.exceptions import BluecellulabError
from bluecellulab.simulation import (
    set_global_condition_parameters,
)
from bluecellulab.simulation.modifications import apply_modifications
from bluecellulab.synapse.synapse_types import SynapseID

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@deprecated("SSim will be removed, use CircuitSimulation instead.")
class SSim:
    """Deprecated former name of :class:`CircuitSimulation`.

    Calling ``SSim(...)`` emits a deprecation warning (through the
    ``deprecated`` decorator) and returns a :class:`CircuitSimulation`
    built from the same arguments, so legacy code keeps working until the
    alias is removed.
    """

    def __new__(cls, *args, **kwargs):
        # Without this, SSim has no __init__ and rejects every argument.
        # CircuitSimulation is defined later in this module; the name is
        # resolved when SSim is called, so the forward reference is safe.
        # Because the returned object is not an SSim instance, Python does
        # not call any __init__ on it a second time.
        return CircuitSimulation(*args, **kwargs)
class CircuitSimulation:
    """Class that loads a circuit simulation to do cell simulations."""

    @load_mod_files
    def __init__(
        self,
        simulation_config: str | Path | SimulationConfig,
        dt: Optional[float] = None,
        record_dt: Optional[float] = None,
        base_seed: Optional[NonNegativeInt] = None,
        rng_mode: Optional[str] = None,
        print_cellstate: bool = False,
        parallel_context=None,
        save_time: Optional[float] = None,
    ):
        """Load the circuit and simulation described by ``simulation_config``.

        Parameters
        ----------
        simulation_config :
            Path (``str`` or ``Path``) of the simulation config file, or an
            already-built ``SimulationConfig``.
        dt :
            Integration timestep. When None, the ``dt`` of the simulation
            config is used, falling back to 0.025 if that value is falsy.
        record_dt :
            Sampling interval of the recordings; forwarded to every cell.
        base_seed :
            Base seed of this simulation. When given it overrides the value
            from the simulation config.
        rng_mode :
            Name of the RNG mode (Compatibility, Random123 or UpdatedMCell).
            When None the mode is taken from the simulation config.
        print_cellstate :
            Use NEURON ``prcellstate`` to dump the state of the simulated
            GIDs.
        parallel_context :
            Optional NEURON ``ParallelContext`` to reuse for MPI runs, for
            callers that already initialised MPI and manage rank/nhost
            themselves. A fresh ``ParallelContext`` is created when None.
        save_time :
            Time (ms) at which ``prcellstate`` is dumped. None (default)
            dumps at the end of the simulation; 0 dumps right after
            initialisation.
        """
        self.record_dt = record_dt

        # Select the access classes that match the format of the config.
        self.circuit_format = determine_circuit_format(simulation_config)
        if self.circuit_format == CircuitFormat.SONATA:
            circuit_access_cls, simulation_access_cls = (
                SonataCircuitAccess,
                SonataSimulationAccess,
            )
        else:
            circuit_access_cls, simulation_access_cls = (
                BluepyCircuitAccess,
                BluepySimulationAccess,
            )
        self.circuit_access: CircuitAccess = circuit_access_cls(simulation_config)
        self.simulation_access: SimulationAccess = simulation_access_cls(
            simulation_config
        )
        SimulationValidator(self.circuit_access).validate()

        if dt is None:
            dt = self.circuit_access.config.dt or 0.025
        self.dt = dt

        # The parallel context is only kept when running on more than one
        # host or when prcellstate dumps were requested.
        if parallel_context is None:
            parallel_context = neuron.h.ParallelContext()
        keep_parallel_context = int(parallel_context.nhost()) > 1 or print_cellstate
        self.pc = parallel_context if keep_parallel_context else None
        self.print_cellstate = print_cellstate
        self.save_time = save_time

        self.rng_settings = RNGSettings.get_instance()
        self.rng_settings.set_seeds(
            rng_mode,
            self.circuit_access.config,
            base_seed=base_seed,
        )

        self.cells: CellDict = CellDict()
        self.gids_instantiated = False
        self.spike_threshold = self.circuit_access.config.spike_threshold
        self.spike_location = self.circuit_access.config.spike_location
        self.projections: list[str] | str | None = None

        set_global_condition_parameters(
            self.circuit_access.config.condition_parameters()
        )

        # Filled in by instantiate_gids().
        self.gids: Optional[GidNamespace] = None
        self._instantiated_cells_mpi: set[CellId] | None = None
def instantiate_gids(
    self,
    cells: int | tuple[str, int] | list[int | tuple[str, int]],
    add_replay: bool = False,
    add_stimuli: bool = False,
    add_synapses: bool = False,
    add_minis: bool = False,
    add_noise_stimuli: bool = False,
    add_hyperpolarizing_stimuli: bool = False,
    add_relativelinear_stimuli: bool = False,
    add_pulse_stimuli: bool = False,
    add_projections: bool | list[str] | str = False,
    intersect_pre_gids: Optional[list] = None,
    interconnect_cells: bool = True,
    pre_spike_trains: None | dict[tuple[str, int], Iterable] | dict[int, Iterable] = None,
    add_shotnoise_stimuli: bool = False,
    add_ornstein_uhlenbeck_stimuli: bool = False,
    add_sinusoidal_stimuli: bool = False,
    add_linear_stimuli: bool = False,
    add_seclamp_stimuli: bool = False,
    add_subthreshold_stimuli: bool = False,
):
    """Instantiate a list of cells.

    This method may only be called once per ``CircuitSimulation``; a second
    call raises ``BluecellulabError``.

    Parameters
    ----------
    cells :
        Cell id or list of cell ids. A single element is wrapped into a
        list.
    add_replay :
        Add the presynaptic spike trains of the large simulation. When
        combined with ``pre_spike_trains`` the spike trains are merged.
        Requires ``add_synapses=True``.
    add_stimuli :
        Add the same stimuli as in the large simulation. Setting this to
        True enables every ``add_*_stimuli`` option below.
    add_synapses :
        Add the touch-detected synapses described by the circuit to the
        cell. This only influences the creation of synapses; it does not
        add any connections. Default is False.
    add_minis :
        Add synaptic minis to the synapses (requires ``add_synapses=True``).
        Default is False.
    add_noise_stimuli :
        Process the 'noise' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_hyperpolarizing_stimuli :
        Process the 'hyperpolarizing' stimulus blocks of the simulation
        config. Enabled automatically by ``add_stimuli=True``.
    add_relativelinear_stimuli :
        Process the 'relativelinear' stimulus blocks of the simulation
        config. Enabled automatically by ``add_stimuli=True``.
    add_pulse_stimuli :
        Process the 'pulse' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_projections :
        Control whether projection edge populations are considered when
        adding synapses.

        * ``False`` (default): intrinsic connectivity only (no projection
          edge populations).
        * ``True``: intrinsic connectivity plus all projection edge
          populations.
        * ``list[str] | str``: intrinsic connectivity plus the named
          projection edge population(s).

        Note: names refer to SONATA edge population names
        (``SnapCircuit.edges`` keys).
    intersect_pre_gids : list of gids
        Only add synapses whose presynaptic gid is in this list.
    interconnect_cells :
        When several gids are instantiated, interconnect them with real
        (non-replay) synapses. Combined with ``add_replay``, replay spike
        trains are only added for presynaptic cells that are not part of
        the instantiated network. Requires ``add_synapses=True``.
    pre_spike_trains :
        Dictionary mapping presynaptic gids to the spike times of those
        cells. Combined with ``add_replay=True``, the spike trains of the
        same gids are merged.
    add_shotnoise_stimuli :
        Process the 'shotnoise' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_ornstein_uhlenbeck_stimuli :
        Process the 'ornstein_uhlenbeck' stimulus blocks of the simulation
        config. Enabled automatically by ``add_stimuli=True``.
    add_sinusoidal_stimuli :
        Process the 'sinusoidal' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_linear_stimuli :
        Process the 'linear' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_seclamp_stimuli :
        Process the 'seclamp' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.
    add_subthreshold_stimuli :
        Process the 'subthreshold' stimulus blocks of the simulation config.
        Enabled automatically by ``add_stimuli=True``.

    Raises
    ------
    BluecellulabError
        If called a second time, or if replay / pre_spike_trains are
        requested without ``add_synapses``.
    """
    if not isinstance(cells, list):
        cells = [cells]

    # convert to CellId objects
    cell_ids: list[CellId] = create_cell_ids(cells)
    if intersect_pre_gids is not None:
        pre_gids: Optional[list[CellId]] = create_cell_ids(intersect_pre_gids)
    else:
        pre_gids = None

    # Re-key user spike trains by CellId so they can be looked up alongside
    # the replay spikes of the main simulation.
    normalized_pre_spike_trains: dict[CellId, Iterable] | None = None
    if pre_spike_trains is not None:
        normalized_pre_spike_trains = {
            create_cell_id(gid): spikes for gid, spikes in pre_spike_trains.items()
        }

    # NOTE(review): the flag is set before the argument checks below, so a
    # call that fails validation still blocks any later retry.
    if self.gids_instantiated:
        raise BluecellulabError(
            "instantiate_gids() is called twice on the "
            "same CircuitSimulation, this is not supported"
        )
    else:
        self.gids_instantiated = True

    if normalized_pre_spike_trains or add_replay:
        if add_synapses is False:
            raise BluecellulabError(
                "You need to set add_synapses to True "
                "if you want to specify use add_replay or "
                "pre_spike_trains"
            )

    # Resolve which projection edge populations get_syn_descriptions uses.
    if add_projections is True:
        self.projections = self.circuit_access.config.get_all_projection_names()
    elif add_projections is None or add_projections is False:
        self.projections = None
    else:
        self.projections = add_projections

    # The global GID namespace is only built when one of these features
    # needs it; otherwise self.gids stays None.
    need_gids = (
        add_synapses
        or (self.pc is not None)
        or add_replay
        or interconnect_cells
        or pre_spike_trains is not None
        or self.print_cellstate
    )
    if need_gids:
        self.gids = self._build_gid_namespace(cell_ids)

    self._add_cells(cell_ids)
    self._apply_modifications()
    if add_synapses:
        self._add_synapses(
            pre_gids=pre_gids,
            add_minis=add_minis,
        )
    if add_replay or interconnect_cells or normalized_pre_spike_trains:
        if add_replay and not add_synapses:
            raise BluecellulabError(
                "add_replay option can not be used if add_synapses is False"
            )
        # Under MPI, every rank must register its cells' GIDs and spike
        # detectors before any cross-rank connection is created.
        if self.pc is not None:
            self._init_instantiated_cells_mpi()
            self._register_gids_for_mpi()
            self.pc.barrier()
            self.pc.setup_transfer()
            self.pc.set_maxstep(1.0)
        self._add_connections(
            add_replay=add_replay,
            interconnect_cells=interconnect_cells,
            user_pre_spike_trains=normalized_pre_spike_trains,
        )
    # add_stimuli is a shortcut that switches on every stimulus family.
    if add_stimuli:
        add_noise_stimuli = True
        add_hyperpolarizing_stimuli = True
        add_relativelinear_stimuli = True
        add_pulse_stimuli = True
        add_sinusoidal_stimuli = True
        add_shotnoise_stimuli = True
        add_ornstein_uhlenbeck_stimuli = True
        add_linear_stimuli = True
        add_seclamp_stimuli = True
        add_subthreshold_stimuli = True
    if (
        add_noise_stimuli
        or add_hyperpolarizing_stimuli
        or add_pulse_stimuli
        or add_relativelinear_stimuli
        or add_shotnoise_stimuli
        or add_ornstein_uhlenbeck_stimuli
        or add_sinusoidal_stimuli
        or add_linear_stimuli
        or add_seclamp_stimuli
        or add_subthreshold_stimuli
    ):
        self._add_stimuli(
            add_noise_stimuli=add_noise_stimuli,
            add_hyperpolarizing_stimuli=add_hyperpolarizing_stimuli,
            add_relativelinear_stimuli=add_relativelinear_stimuli,
            add_pulse_stimuli=add_pulse_stimuli,
            add_shotnoise_stimuli=add_shotnoise_stimuli,
            add_ornstein_uhlenbeck_stimuli=add_ornstein_uhlenbeck_stimuli,
            add_sinusoidal_stimuli=add_sinusoidal_stimuli,
            add_linear_stimuli=add_linear_stimuli,
            add_seclamp_stimuli=add_seclamp_stimuli,
            add_subthreshold_stimuli=add_subthreshold_stimuli,
        )

    self.recording_index, self.sites_index = prepare_recordings_for_reports(
        cells=self.cells,
        simulation_config=self.circuit_access.config,
    )

    # add spike recordings
    for cell in self.cells.values():
        if not cell.is_recording_spikes(
            self.spike_location, threshold=self.spike_threshold
        ):
            cell.start_recording_spikes(
                None, location=self.spike_location, threshold=self.spike_threshold
            )
def _add_stimuli(
    self,
    add_noise_stimuli=False,
    add_hyperpolarizing_stimuli=False,
    add_relativelinear_stimuli=False,
    add_pulse_stimuli=False,
    add_shotnoise_stimuli=False,
    add_ornstein_uhlenbeck_stimuli=False,
    add_sinusoidal_stimuli=False,
    add_linear_stimuli=False,
    add_seclamp_stimuli=False,
    add_subthreshold_stimuli=False,
) -> None:
    """Instantiate all the stimuli of the simulation config on the cells.

    Each stimulus entry is resolved into ``(cell_id, section, segx,
    section_name)`` targets. A ``compartment_set`` entry uses the named
    compartment set. A ``node_set`` entry uses the soma at 0.5 of every
    instantiated cell in that node set. Each ``add_*`` flag enables one
    stimulus family. ``SynapseReplay`` entries are not gated by a flag.

    Raises
    ------
    ValueError
        If a stimulus has neither a node_set nor a compartment_set, or if
        a stimulus of an unsupported type reaches a target.
    """
    stimuli_entries = self.circuit_access.config.get_all_stimuli_entries()
    # Also add the injections / stimulations as in the cortical model
    # check in which StimulusInjects the gid is a target
    # Every noise or shot noise stimulus gets a new seed
    noisestim_count = 0
    shotnoise_stim_count = 0
    ornstein_uhlenbeck_stim_count = 0

    # Pre-fetch compartment sets (if available) to detect when a stimulus
    # target refers to a named compartment_set rather than a node_set.
    # A ValueError is treated as "no compartment sets configured";
    # _targets_from_compartment_set reports it if one is actually needed.
    compartment_sets: dict[str, dict[str, Any]] | None = None
    try:
        compartment_sets = self.circuit_access.config.get_compartment_sets()
    except ValueError:
        pass

    for stimulus in stimuli_entries:
        # Build a unified list of (cell_id, section, segx, section_name) targets
        targets: list[tuple] = []
        if stimulus.compartment_set is not None:
            targets = self._targets_from_compartment_set(stimulus, compartment_sets)
        elif stimulus.node_set is not None:
            gids_of_target = self.circuit_access.get_target_cell_ids(
                stimulus.node_set
            )
            for cell_id in self.cells:
                if cell_id not in gids_of_target:
                    continue
                sec = self.cells[cell_id].soma
                sec_name = sec.name().split(".")[-1]
                targets.append((cell_id, sec, 0.5, sec_name))
        else:
            raise ValueError(
                f"Stimulus '{stimulus}' has neither node_set nor compartment_set; "
                "cannot resolve target location."
            )

        for cell_id, sec, segx, sec_name in targets:
            if isinstance(stimulus, circuit_stimulus_definitions.Noise):
                if add_noise_stimuli:
                    self.cells[cell_id].add_replay_noise(
                        stimulus,
                        noise_seed=None,
                        noisestim_count=noisestim_count,
                        section=sec,
                        segx=segx,
                    )
            elif isinstance(stimulus, circuit_stimulus_definitions.Hyperpolarizing):
                if add_hyperpolarizing_stimuli:
                    self.cells[cell_id].add_replay_hypamp(
                        stimulus, section=sec, segx=segx
                    )
            elif isinstance(stimulus, circuit_stimulus_definitions.Pulse):
                if add_pulse_stimuli:
                    self.cells[cell_id].add_pulse(stimulus, section=sec, segx=segx)
            elif isinstance(stimulus, circuit_stimulus_definitions.Linear):
                if add_linear_stimuli:
                    self.cells[cell_id].add_replay_linear(
                        stimulus, section=sec, segx=segx
                    )
            elif isinstance(stimulus, circuit_stimulus_definitions.RelativeLinear):
                if add_relativelinear_stimuli:
                    self.cells[cell_id].add_replay_relativelinear(
                        stimulus, section=sec, segx=segx
                    )
            elif isinstance(stimulus, circuit_stimulus_definitions.ShotNoise):
                if add_shotnoise_stimuli:
                    self.cells[cell_id].add_replay_shotnoise(
                        sec,
                        segx,
                        stimulus,
                        shotnoise_stim_count=shotnoise_stim_count,
                    )
            elif isinstance(
                stimulus, circuit_stimulus_definitions.RelativeShotNoise
            ):
                if add_shotnoise_stimuli:
                    self.cells[cell_id].add_replay_relative_shotnoise(
                        sec,
                        segx,
                        stimulus,
                        shotnoise_stim_count=shotnoise_stim_count,
                    )
            elif isinstance(
                stimulus, circuit_stimulus_definitions.OrnsteinUhlenbeck
            ):
                if add_ornstein_uhlenbeck_stimuli:
                    self.cells[cell_id].add_ornstein_uhlenbeck(
                        sec,
                        segx,
                        stimulus,
                        stim_count=ornstein_uhlenbeck_stim_count,
                    )
            elif isinstance(
                stimulus, circuit_stimulus_definitions.RelativeOrnsteinUhlenbeck
            ):
                if add_ornstein_uhlenbeck_stimuli:
                    self.cells[cell_id].add_relative_ornstein_uhlenbeck(
                        sec,
                        segx,
                        stimulus,
                        stim_count=ornstein_uhlenbeck_stim_count,
                    )
            elif isinstance(stimulus, circuit_stimulus_definitions.Sinusoidal):
                if add_sinusoidal_stimuli:
                    # NOTE(review): sec/segx are not forwarded here, unlike the
                    # other families; confirm add_sinusoidal's placement.
                    self.cells[cell_id].add_sinusoidal(stimulus)
            elif isinstance(stimulus, circuit_stimulus_definitions.SEClamp):
                # sonata only
                if add_seclamp_stimuli:
                    self.cells[cell_id].add_seclamp(stimulus, section=sec, segx=segx)
            elif isinstance(stimulus, circuit_stimulus_definitions.SubThreshold):
                if add_subthreshold_stimuli:
                    self.cells[cell_id].add_replay_subthreshold(
                        stimulus, section=sec, segx=segx
                    )
            elif isinstance(
                stimulus, circuit_stimulus_definitions.SynapseReplay
            ):  # sonata only
                if self.circuit_access.target_contains_cell(
                    stimulus.target, cell_id
                ):
                    self.cells[cell_id].add_synapse_replay(
                        stimulus, self.spike_threshold, self.spike_location
                    )
            else:
                raise ValueError(
                    "Found stimulus with pattern %s, not supported" % stimulus
                )
            logger.debug(
                f"Added {stimulus} to cell_id {cell_id} at {sec_name}({segx})"
            )

        # Counters advance once per stimulus entry so that each noise-like
        # stimulus gets a distinct count.
        if isinstance(stimulus, Noise):
            noisestim_count += 1
        elif isinstance(stimulus, (ShotNoise, RelativeShotNoise)):
            shotnoise_stim_count += 1
        elif isinstance(stimulus, (OrnsteinUhlenbeck, RelativeOrnsteinUhlenbeck)):
            ornstein_uhlenbeck_stim_count += 1


def _add_synapses(self, pre_gids=None, add_minis=False):
    """Instantiate the synapses of every instantiated cell."""
    for cell_id in self.cells:
        self._add_cell_synapses(
            cell_id,
            pre_gids=pre_gids,
            add_minis=add_minis,
        )


def _add_cell_synapses(
    self, cell_id: CellId, pre_gids=None, add_minis=False
) -> None:
    """Instantiate the synapses of one cell.

    Parameters
    ----------
    cell_id :
        Cell whose incoming synapses are created.
    pre_gids :
        Optional list of presynaptic CellIds. When given, only synapses from
        these cells are created.
    add_minis :
        Also add spontaneous minis to every created synapse.
    """
    syn_descriptions = self.get_syn_descriptions(cell_id)

    if pre_gids is not None:
        if self.circuit_format == CircuitFormat.SONATA:
            syn_descriptions = self._intersect_pre_gids_cell_ids_multipopulation(
                syn_descriptions, pre_gids
            )
        else:
            syn_descriptions = self._intersect_pre_gids(syn_descriptions, pre_gids)

    # Check if there are any presynaptic cells, otherwise skip adding
    # synapses
    if syn_descriptions.empty:
        logger.warning(
            f"No presynaptic cells found for gid {cell_id}, no synapses added"
        )
    else:
        for idx, syn_description in syn_descriptions.iterrows():
            popids = (
                syn_description["source_popid"],
                syn_description["target_popid"],
            )
            self._instantiate_synapse(
                cell_id=cell_id,
                syn_id=idx,  # type: ignore
                syn_description=syn_description,
                add_minis=add_minis,
                popids=popids,
            )
        # Log the number of synapses, not the whole description table.
        logger.info(f"Added {len(syn_descriptions)} synapses for gid {cell_id}")
        if add_minis:
            logger.info(f"Added minis for {cell_id=}")


def _targets_from_compartment_set(
    self,
    stimulus: circuit_stimulus_definitions.Stimulus,
    compartment_sets: dict[str, dict[str, Any]] | None,
) -> list[tuple]:
    """Resolve a compartment_set stimulus into (cell_id, section, segx, sec_name) targets.

    Cells outside the compartment set's population are skipped, as are
    cells whose segments cannot be resolved (logged at debug level).

    Raises
    ------
    ValueError
        If no compartment sets are configured or the named set is missing.
    """
    if compartment_sets is None:
        raise ValueError(
            "Simulation config provides compartment_set stimuli but "
            "no 'compartment_sets_file' is configured."
        )

    comp_name = stimulus.compartment_set
    if comp_name not in compartment_sets:
        raise ValueError(
            f"Compartment set '{comp_name}' not found in compartment_sets file."
        )

    comp_entry = compartment_sets[comp_name]
    comp_nodes = comp_entry.get("compartment_set", [])
    population_name = comp_entry.get("population")

    targets: list[tuple] = []
    for cell_id in self.cells:
        if (
            population_name is not None
            and getattr(cell_id, "population_name", None) != population_name
        ):
            continue
        try:
            resolved = self.cells[cell_id].resolve_segments_from_compartment_set(
                cell_id.id if hasattr(cell_id, "id") else cell_id,
                comp_nodes,
            )
        except (ValueError, TypeError) as e:
            logger.debug(
                f"Failed to resolve compartment_set for cell {cell_id}: {e}"
            )
            continue
        for sec, sec_name, segx in resolved:
            targets.append((cell_id, sec, segx, sec_name))
    return targets


@staticmethod
def _intersect_pre_gids(syn_descriptions, pre_gids: list[CellId]) -> pd.DataFrame:
    """Return the synapse descriptions with pre_gids intersected."""
    _pre_gids = {x.id for x in pre_gids}
    return syn_descriptions[
        syn_descriptions[SynapseProperty.PRE_GID].isin(_pre_gids)
    ]


@staticmethod
def _intersect_pre_gids_cell_ids_multipopulation(
    syn_descriptions, pre_cell_ids: list[CellId]
) -> pd.DataFrame:
    """Return the synapse descriptions with pre_cell_ids intersected.

    Supports multipopulations: a row is kept when its
    ``(source_population_name, PRE_GID)`` pair matches the
    ``(population_name, id)`` of one of ``pre_cell_ids``.
    """
    # One set lookup per row instead of scanning pre_cell_ids for each row.
    wanted = {(cell.population_name, cell.id) for cell in pre_cell_ids}
    pairs = zip(
        syn_descriptions["source_population_name"],
        syn_descriptions[SynapseProperty.PRE_GID],
    )
    mask = np.fromiter(
        (pair in wanted for pair in pairs),
        dtype=bool,
        count=len(syn_descriptions),
    )
    return syn_descriptions[mask]
def get_syn_descriptions(self, cell_id: int | tuple[str, int]) -> pd.DataFrame:
    """Return the synapse descriptions of a cell as a DataFrame.

    ``cell_id`` may be a plain gid or a ``(population, gid)`` tuple. The
    projections selected by ``instantiate_gids`` (``self.projections``) are
    passed through to the circuit access.
    """
    resolved_id = create_cell_id(cell_id)
    access = self.circuit_access
    return access.extract_synapses(resolved_id, projections=self.projections)
[docs] @staticmethod def merge_pre_spike_trains(*train_dicts) -> dict[CellId, np.ndarray]: """Merge presynaptic spike train dicts.""" filtered_dicts = [d for d in train_dicts if isinstance(d, dict) and d] if not filtered_dicts: return {} all_keys = set().union(*[d.keys() for d in filtered_dicts]) result = {} for k in all_keys: valid_arrays = [] for d in filtered_dicts: if k in d: val = d[k] if isinstance(val, (np.ndarray, list)) and len(val) > 0: valid_arrays.append(np.asarray(val)) if valid_arrays: result[k] = np.sort(np.concatenate(valid_arrays)) return result
def _find_matching_override(self, overrides, pre: CellId, post: CellId):
    """Return the last connection override whose source contains ``pre``
    and whose target contains ``post``, or None if none matches."""
    matched = None
    for ov in overrides:
        # ov.source and ov.target are nodeset names
        if self.circuit_access.target_contains_cell(
            ov.source, pre
        ) and self.circuit_access.target_contains_cell(ov.target, post):
            matched = ov  # "last match wins" like Neurodamus ordering
    return matched


def _add_connections(
    self,
    add_replay=None,
    interconnect_cells=None,
    user_pre_spike_trains: None | dict[CellId, Iterable] = None,
) -> None:
    """Instantiate the (replay and real) connections in the network.

    For every synapse of every instantiated cell:

    * a matching connection override with weight 0 skips the synapse;
    * if ``interconnect_cells`` is set and the presynaptic cell is
      instantiated (locally, or on any rank under MPI), a real connection
      is created;
    * otherwise a replay connection is created. Its spike train is the
      merge of the main simulation's spikes (when ``add_replay``) and
      ``user_pre_spike_trains``.

    Raises
    ------
    BluecellulabError
        If ``user_pre_spike_trains`` provides spikes for a presynaptic cell
        that is connected with a real connection.
    """
    pre_spike_trains = self.simulation_access.get_spikes() if add_replay else {}
    pre_spike_trains = self.merge_pre_spike_trains(
        pre_spike_trains,
        user_pre_spike_trains,
    )
    connections_overrides = (
        self.circuit_access.config.connection_entries()
        if hasattr(self, "circuit_access")
        else []
    )

    for post_gid in self.cells:
        for syn_id in self.cells[post_gid].synapses:
            synapse = self.cells[post_gid].synapses[syn_id]
            syn_description: pd.Series = synapse.syn_description
            delay_weights = synapse.delay_weights
            source_population = syn_description["source_population_name"]
            pre_local_id = CellId(
                source_population, int(syn_description[SynapseProperty.PRE_GID])
            )

            ov = self._find_matching_override(
                connections_overrides, pre_local_id, post_gid
            )
            if ov is not None and ov.weight == 0.0:
                logger.debug(
                    "Skipping connection due to zero weight override: %s -> %s | syn_id=%s",
                    pre_local_id, post_gid, syn_id
                )
                continue

            if self.pc is None:
                real_synapse_connection = bool(interconnect_cells) and (
                    pre_local_id in self.cells
                )
            else:
                real_synapse_connection = bool(interconnect_cells) and (
                    self._instantiated_cells_mpi is not None
                    and pre_local_id in self._instantiated_cells_mpi
                )

            if real_synapse_connection:
                if (
                    user_pre_spike_trains is not None
                    and pre_local_id in user_pre_spike_trains
                ):
                    raise BluecellulabError(
                        """Specifying prespike trains of real connections"""
                        """ is not allowed."""
                    )
                if self.pc is None:
                    # serial only
                    connection = bluecellulab.Connection(
                        self.cells[post_gid].synapses[syn_id],
                        pre_spiketrain=None,
                        pre_cell=self.cells[pre_local_id],
                        stim_dt=self.dt,
                        parallel_context=None,
                        spike_threshold=self.spike_threshold,
                        spike_location=self.spike_location,
                    )
                else:
                    # MPI cross-rank
                    pre_gid = self.global_gid(
                        pre_local_id.population_name, pre_local_id.id
                    )
                    connection = bluecellulab.Connection(
                        self.cells[post_gid].synapses[syn_id],
                        pre_spiketrain=None,
                        pre_gid=pre_gid,
                        pre_cell=None,
                        stim_dt=self.dt,
                        parallel_context=self.pc,
                        spike_threshold=self.spike_threshold,
                        spike_location=self.spike_location,
                    )
                logger.debug(
                    f"Added real connection between {pre_local_id} and {post_gid}, {syn_id}"
                )
            else:  # replay connection
                pre_spiketrain = pre_spike_trains.get(pre_local_id, None)
                connection = bluecellulab.Connection(
                    self.cells[post_gid].synapses[syn_id],
                    pre_spiketrain=pre_spiketrain,
                    pre_cell=None,
                    stim_dt=self.dt,
                    parallel_context=None,
                    spike_threshold=self.spike_threshold,
                    spike_location=self.spike_location,
                )
                logger.debug(
                    f"Added replay connection from {pre_local_id} to {post_gid}, {syn_id}"
                )

            if ov is not None:
                logger.debug(
                    "Override matched: %s -> %s | syn_id=%s | weight=%s delay=%s",
                    pre_local_id, post_gid, syn_id, ov.weight, ov.delay
                )
                syn_delay = getattr(ov, "synapse_delay_override", None)
                if syn_delay is not None:
                    connection.set_netcon_delay(float(syn_delay))
                    logger.debug(
                        "Applied synapse_delay_override %.4g ms to %s -> %s | syn_id=%s",
                        syn_delay, pre_local_id, post_gid, syn_id
                    )
                if ov.delay is not None:
                    logger.warning(
                        "SONATA override 'delay' (delayed weight activation) is not supported yet; "
                        "applying weight immediately. %s -> %s | syn_id=%s | delay=%s",
                        pre_local_id, post_gid, syn_id, ov.delay
                    )
                if ov.weight is not None:
                    connection.set_weight_scalar(float(ov.weight))
                    logger.debug(
                        "Applied weight override factor %.4g to %s -> %s | syn_id=%s | final_weight=%.4g",
                        ov.weight, pre_local_id, post_gid, syn_id,
                        connection.post_netcon_weight
                    )

            self.cells[post_gid].connections[syn_id] = connection
            for delay, weight_scale in delay_weights:
                self.cells[post_gid].add_replay_delayed_weight(
                    syn_id, delay, weight_scale * connection.weight
                )

        if len(self.cells[post_gid].connections) > 0:
            logger.debug(f"Added synaptic connections for target {post_gid}")


def _add_cells(self, cell_ids: list[CellId]) -> None:
    """Instantiate cells from a gid list.

    Replaces ``self.cells`` with a new CellDict holding one cell per id.
    ``cell.post_gid`` is assigned only when a GID namespace exists
    (``self.gids`` is not None). Without a namespace, ``global_gid`` would
    raise RuntimeError.
    """
    self.cells = CellDict()
    for cell_id in cell_ids:
        cell = self.create_cell_from_circuit(cell_id)
        # instantiate_gids only builds the namespace when a feature needs
        # it; without one there is no global gid to assign.
        if self.gids is not None:
            cell.post_gid = self.global_gid(cell_id.population_name, cell_id.id)
        self.cells[cell_id] = cell
        if self.circuit_access.node_properties_available:
            cell.connect_to_circuit(SonataProxy(cell_id, self.circuit_access))


def _apply_modifications(self) -> None:
    """Apply condition modifications from the simulation config to cells.

    Configs that do not implement ``get_modifications`` (NotImplementedError
    or AttributeError) are treated as having no modifications.
    """
    try:
        modifications = self.circuit_access.config.get_modifications()
    except (NotImplementedError, AttributeError):
        return
    if modifications:
        apply_modifications(self.cells, modifications, self.circuit_access)


def _instantiate_synapse(self, cell_id: CellId, syn_id: SynapseID, syn_description,
                         add_minis=False, popids=(0, 0)) -> None:
    """Instantiate one synapse for a given gid, syn_id and syn_description.

    The synapse (and optional minis) is only created when the connection
    parameters resolved for the pre/post pair report ``add_synapse``.
    """
    pre_cell_id = CellId(
        syn_description["source_population_name"],
        int(syn_description[SynapseProperty.PRE_GID]),
    )
    syn_connection_parameters = get_synapse_connection_parameters(
        circuit_access=self.circuit_access,
        pre_cell=pre_cell_id,
        post_cell=cell_id,
    )
    if syn_connection_parameters["add_synapse"]:
        condition_parameters = self.circuit_access.config.condition_parameters()

        self.cells[cell_id].add_replay_synapse(
            syn_id,
            syn_description,
            syn_connection_parameters,
            condition_parameters,
            popids=popids,
            extracellular_calcium=self.circuit_access.config.extracellular_calcium,
        )
        if add_minis:
            mini_frequencies = self.circuit_access.fetch_mini_frequencies(cell_id)
            logger.debug(
                f"Adding minis for synapse {syn_id}: syn_description={syn_description}, connection={syn_connection_parameters}, frequency={mini_frequencies}"
            )

            self.cells[cell_id].add_replay_minis(
                syn_id,
                syn_description,
                syn_connection_parameters,
                popids=popids,
                mini_frequencies=mini_frequencies,
            )
def run(
    self,
    t_stop: Optional[float] = None,
    v_init: Optional[float] = None,
    celsius: Optional[float] = None,
    dt: Optional[float] = None,
    forward_skip: bool = True,
    forward_skip_value: Optional[float] = None,
    cvode: bool = False,
    show_progress: bool = False,
):
    """Simulate the Circuit.

    Parameters
    ----------
    t_stop :
        Run the simulation until t_stop. Defaults to the config ``tstop``,
        or 0.0 if the config has none.
    v_init :
        Initial voltage when the simulation starts. Defaults to the config
        value.
    celsius :
        Temperature at which the simulation runs. Defaults to the config
        value.
    dt :
        Deprecated; see the note below. The timestep resolved at
        construction is always used.
    forward_skip :
        [compatibility/non-sonata] Enable/disable ForwardSkip. When
        forward_skip_value is None, forward skip is only enabled if the
        simulation config has a ForwardSkip value.
    forward_skip_value :
        [compatibility/non-sonata] Override the ForwardSkip value of the
        simulation config. If None, the config value is used. When neither
        is available, a positive SONATA ``tstart`` is used as the skip
        value and forward skip is enabled.
    cvode :
        Force the simulation to run with a variable timestep. Not possible
        when there are stochastic channels in the neuron model. When
        enabled, results from a large network simulation will not be
        exactly reproduced.
    show_progress :
        Show a progress bar during simulations. When enabled, results from
        a large network simulation will not be exactly reproduced.

    Note
    ----
    Passing ``dt`` to ``run()`` is deprecated and will be removed in a
    future release. The simulation timestep used is the one resolved at
    :class:`CircuitSimulation` construction (either the explicit ``dt``
    passed to the constructor or the value from the ``simulation_config``).
    """
    if t_stop is None:
        t_stop = self.circuit_access.config.tstop
    if t_stop is None:  # type narrowing
        t_stop = 0.0

    if dt is not None and dt != self.dt:
        warnings.warn(
            "Passing `dt` to `run()` to change the simulation timestep is deprecated; "
            "pass `dt` to the CircuitSimulation constructor instead. "
            "This behavior will be removed in a future release.",
            DeprecationWarning,
        )
    dt = self.dt

    config_forward_skip_value = self.circuit_access.config.forward_skip  # legacy
    config_tstart = self.circuit_access.config.tstart or 0.0  # SONATA

    # Determine effective skip value and flag
    if forward_skip_value is not None:
        # User explicitly provided value → use it
        effective_skip_value = forward_skip_value
        effective_skip = forward_skip
    elif config_forward_skip_value is not None:
        # Use legacy config if available
        effective_skip_value = config_forward_skip_value
        effective_skip = forward_skip
    elif config_tstart > 0.0:
        # Use SONATA tstart *only* if no other skip value was provided
        effective_skip_value = config_tstart
        effective_skip = True
    else:
        # No skip
        effective_skip_value = None
        effective_skip = False

    if celsius is None:
        celsius = self.circuit_access.config.celsius
    NeuronGlobals.get_instance().temperature = celsius
    if v_init is None:
        v_init = self.circuit_access.config.v_init
    NeuronGlobals.get_instance().v_init = v_init

    sim = bluecellulab.Simulation(self.pc)
    for cell_id in self.cells:
        sim.add_cell(self.cells[cell_id])

    self.fih_prcellstate = None
    if self.pc is not None and self.print_cellstate:
        def dump():
            for cell_id in self.cells:
                g = self.global_gid(cell_id.population_name, cell_id.id)
                self.pc.prcellstate(g, f"bluecellulab_t={neuron.h.t}")

        def schedule_dump():
            # Default to the end of *this* run (the effective t_stop), not
            # the config tstop, which differs when t_stop is passed in.
            t_dump = self.save_time if self.save_time is not None else t_stop
            neuron.h.cvode.event(t_dump, dump)

        # The FInitializeHandler re-schedules the dump on every
        # initialisation; keep a reference so it is not garbage collected.
        self.fih_prcellstate = neuron.h.FInitializeHandler(1, schedule_dump)

    if show_progress:
        logger.warning(
            "show_progress enabled, this will very likely "
            "break the exact reproducibility of large network "
            "simulations"
        )
    sim.run(
        tstop=t_stop,
        cvode=cvode,
        dt=dt,
        forward_skip=effective_skip,
        forward_skip_value=effective_skip_value,
        show_progress=show_progress,
    )
def get_mainsim_voltage_trace(
    self, cell_id: int | tuple[str, int], t_start=None, t_stop=None, t_step=None
) -> np.ndarray:
    """Return the soma voltage of a cell as recorded by the main simulation.

    Parameters
    ----------
    cell_id:
        Id of the cell of interest (gid or ``(population, gid)`` tuple).
    t_start, t_stop:
        Time range of interest; by default the report's time range.
    t_step:
        Sampling step, expected to be a multiple of the report time step
        (which is also the default).

    Returns
    -------
    np.ndarray
        One-dimensional array of voltages.
    """
    resolved_id = create_cell_id(cell_id)
    return self.simulation_access.get_soma_voltage(
        resolved_id, t_start, t_stop, t_step
    )
def get_mainsim_time_trace(self, t_step=None) -> np.ndarray:
    """Return the time points of the main simulation's soma report.

    Parameters
    ----------
    t_step:
        Sampling step, expected to be a multiple of the report time step
        (which is also the default).

    Returns
    -------
    np.ndarray
        One-dimensional array of times.
    """
    access = self.simulation_access
    return access.get_soma_time_trace(t_step)
def get_time(self) -> np.ndarray:
    """Get the time vector for the recordings, contains negative times.

    The negative times occur as a result of ForwardSkip. The vector of the
    first instantiated cell is returned.

    Raises
    ------
    RuntimeError
        If no cells have been instantiated.
    """
    # next() on an empty iterator would leak a bare StopIteration; raise a
    # descriptive error instead.
    first_key = next(iter(self.cells), None)
    if first_key is None:
        raise RuntimeError(
            "No cells have been instantiated; call instantiate_gids() first."
        )
    return self.cells[first_key].get_time()
def get_time_trace(self, t_start=None, t_stop=None, t_step=None) -> np.ndarray:
    """Get the time vector for the recordings, negative times removed.

    Parameters
    ----------
    t_start, t_stop:
        Time range of interest (ms). ``t_start`` defaults to 0 (negative
        values are clamped to 0); ``t_stop`` defaults to the end of the
        recording.
    t_step:
        Sampling step; should be a multiple of the recording interval
        (``record_dt`` when set, otherwise ``dt``). Default: no resampling.

    Returns
    -------
    np.ndarray
        1D array of time points.
    """
    time = self.get_time()
    time = time[time >= 0.0]
    if t_start is None or t_start < 0:
        t_start = 0
    if t_stop is None:
        t_stop = np.inf
    time = time[(time >= t_start) & (time <= t_stop)]
    if t_step is not None:
        # Cells are created with record_dt (see fetch_cell_kwargs), so the
        # recorded vectors are spaced by record_dt when it is set, not dt.
        recording_dt = self.record_dt if self.record_dt is not None else self.dt
        time = _sample_array(time, t_step / recording_dt)
    return time
def get_voltage_trace(
    self, cell_id: int | tuple[str, int], t_start=None, t_stop=None, t_step=None
) -> np.ndarray:
    """Get the soma voltage vector for the cell_id, negative times removed.

    Parameters
    ----------
    cell_id:
        Id of the cell of interest (gid or ``(population, gid)`` tuple).
    t_start, t_stop:
        Time range of interest (ms). ``t_start`` defaults to 0 (negative
        values are clamped to 0); ``t_stop`` defaults to the end of the
        recording.
    t_step:
        Sampling step; should be a multiple of the recording interval
        (``record_dt`` when set, otherwise ``dt``). Default: no resampling.

    Returns
    -------
    np.ndarray
        One-dimensional array of voltages.
    """
    cell_id = create_cell_id(cell_id)
    time = self.get_time()
    voltage = self.cells[cell_id].get_soma_voltage()
    if t_start is None or t_start < 0:
        t_start = 0
    if t_stop is None:
        t_stop = np.inf
    voltage = voltage[np.where((time >= t_start) & (time <= t_stop))]
    if t_step is not None:
        # Cells are created with record_dt (see fetch_cell_kwargs), so the
        # recorded vectors are spaced by record_dt when it is set, not dt.
        recording_dt = self.record_dt if self.record_dt is not None else self.dt
        voltage = _sample_array(voltage, t_step / recording_dt)
    return voltage
def delete(self):
    """Release every cell held by this CircuitSimulation.

    NEURON objects must be freed explicitly: each cell's ``delete()`` is
    called, then every entry is removed from ``self.cells``. Does nothing
    if ``cells`` was never set (e.g. construction failed early).
    """
    if not hasattr(self, "cells"):
        return
    for cell in self.cells.values():
        cell.delete()
    # Snapshot the keys first; the mapping shrinks while we delete.
    for key in list(self.cells):
        del self.cells[key]
def __del__(self):
    """Finalizer: free all allocated NEURON objects by calling ``delete``."""
    release = self.delete
    release()
def fetch_cell_kwargs(self, cell_id: CellId) -> dict:
    """Collect the keyword arguments needed to instantiate a Cell object.

    Parameters
    ----------
    cell_id : CellId
        Cell whose template, morphology and e-model data are looked up.

    Returns
    -------
    dict
        Keys ``template_path``, ``morphology_path``, ``cell_id``,
        ``record_dt``, ``template_format`` and ``emodel_properties``.
    """
    access = self.circuit_access
    emodel_properties = access.get_emodel_properties(cell_id)
    return dict(
        template_path=access.emodel_path(cell_id),
        morphology_path=access.morph_filepath(cell_id),
        cell_id=cell_id,
        record_dt=self.record_dt,
        template_format=access.get_template_format(),
        emodel_properties=emodel_properties,
    )
def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell:
    """Instantiate a ``bluecellulab.Cell`` for ``cell_id`` from circuit data.

    Only the six constructor arguments below are forwarded from
    ``fetch_cell_kwargs``; any other keys it might return are ignored.
    """
    forwarded = (
        "template_path",
        "morphology_path",
        "cell_id",
        "record_dt",
        "template_format",
        "emodel_properties",
    )
    fetched = self.fetch_cell_kwargs(cell_id)
    return bluecellulab.Cell(**{name: fetched[name] for name in forwarded})
def global_gid(self, pop: str, local_id: int) -> int:
    """Translate a (population, local id) pair into a global GID.

    Parameters
    ----------
    pop : str
        Population name.
    local_id : int
        Local ID within the population.

    Returns
    -------
    int
        Global GID, as computed by the GID namespace.

    Raises
    ------
    RuntimeError
        If the GID namespace (``self.gids``) has not been initialized.
    """
    namespace = self.gids
    if namespace is not None:
        return namespace.global_gid(pop, local_id)
    raise RuntimeError("GID namespace not initialized yet.")
def _build_gid_namespace(self, cell_ids: list[CellId]) -> GidNamespace:
    """Assign every population a global GID offset and wrap them in a GidNamespace.

    The largest raw id per population is taken from the locally loaded
    ``cell_ids``. Virtual populations are covered by their full size, as
    reported by ``circuit_access.virtual_population_sizes()``. When
    ``self.pc`` is set, these maxima are gathered on rank 0 and merged there.
    Rank 0 turns them into offsets and broadcasts them, so every rank builds
    the same namespace.
    """
    highest: dict[str, int] = {}

    # Real populations: highest id among the cells loaded on this rank.
    for cid in cell_ids:
        pop_name = cid.population_name
        raw_id = int(cid.id)
        highest[pop_name] = max(highest.get(pop_name, -1), raw_id)

    # Virtual populations: reserve their whole id range.
    for pop_name, size in self.circuit_access.virtual_population_sizes().items():
        highest[pop_name] = max(highest.get(pop_name, -1), int(size) - 1)

    if self.pc is None:
        return GidNamespace(self._compute_offsets_from_max(sorted(highest), highest))

    # Collective section: every rank must reach both py_gather and py_broadcast.
    per_rank = self.pc.py_gather(highest, 0)
    offsets = None
    if int(self.pc.id()) == 0:
        merged: dict[str, int] = {}
        for rank_max in per_rank:
            for pop_name, raw_id in rank_max.items():
                merged[pop_name] = max(merged.get(pop_name, -1), raw_id)
        offsets = self._compute_offsets_from_max(sorted(merged), merged)
    offsets = self.pc.py_broadcast(offsets, 0)
    return GidNamespace(offsets)


def _register_gids_for_mpi(self) -> None:
    """Register each locally instantiated cell's global GID with ``self.pc``.

    For every cell, the GID is assigned to this rank via ``set_gid2node``.
    A spike-detector NetCon is then created at ``self.spike_location`` with
    ``self.spike_threshold`` and handed to ``pc.cell`` together with the GID.
    """
    assert self.pc is not None
    assert self.gids is not None
    for cid, cell in self.cells.items():
        gid = self.global_gid(cid.population_name, cid.id)
        self.pc.set_gid2node(gid, int(self.pc.id()))
        detector = cell.create_netcon_spikedetector(
            None,
            location=self.spike_location,
            threshold=self.spike_threshold,
        )
        self.pc.cell(gid, detector)


def _compute_offsets_from_max(self, pops_sorted: list[str], max_raw: dict[str, int]) -> dict[str, int]:
    """Lay populations out back to back, each starting on a multiple of 1000.

    The first population starts at 0. Each following population starts at
    the previous population's offset plus that population's id count
    (``max_raw[prev] + 1``), rounded up to the next multiple of 1000.
    """
    offsets: dict[str, int] = {}
    previous = None
    for pop in pops_sorted:
        if previous is None:
            start = 0
        else:
            prev_end = offsets[previous] + int(max_raw[previous]) + 1
            # Integer ceiling to the next multiple of 1000.
            start = -(-prev_end // 1000) * 1000
        offsets[pop] = start
        previous = pop
    return offsets


def _init_instantiated_cells_mpi(self) -> None:
    """Build the global set of instantiated CellIds across all MPI ranks.

    Each rank contributes the keys of ``self.cells``. Rank 0 unions them into
    one set, which is broadcast and stored in
    ``self._instantiated_cells_mpi`` on every rank.
    """
    assert self.pc is not None
    per_rank = self.pc.py_gather(list(self.cells.keys()), 0)
    union = None
    if int(self.pc.id()) == 0:
        union = set().union(*per_rank)
    self._instantiated_cells_mpi = self.pc.py_broadcast(union, 0)