SE3 Molecular Flow
Module: generative_models.models.flow.se3_molecular
Source: generative_models/models/flow/se3_molecular.py
Overview
SE3MolecularFlow is a simplified molecular conformation flow baseline built
from geometry-conditioned helper layers and affine coordinate transforms.
The current runtime is useful as a molecular flow baseline, but its behavior under global rotations and translations of the input coordinates is not verified, so it should not be treated as SE(3)-invariant or SE(3)-equivariant.
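Because no rotation or translation guarantee is verified, it can help to measure empirically how much a score changes under rigid motions. The following is a minimal sketch in plain Python, not part of the module: `rotation_matrix`, `apply_rigid`, `max_score_gap`, and the two stand-in scoring functions are hypothetical names introduced for illustration, and neither stand-in is the model's `log_prob`.

```python
import math
import random


def rotation_matrix(axis, angle):
    """Rodrigues' formula: 3x3 rotation about a unit `axis` by `angle` radians."""
    x, y, z = axis
    c, s = math.cos(angle), math.sin(angle)
    t = 1.0 - c
    return [
        [c + x * x * t, x * y * t - z * s, x * z * t + y * s],
        [y * x * t + z * s, c + y * y * t, y * z * t - x * s],
        [z * x * t - y * s, z * y * t + x * s, c + z * z * t],
    ]


def apply_rigid(coords, rotation, translation):
    """Rotate, then translate, a list of (x, y, z) atom positions."""
    return [
        tuple(
            sum(rotation[i][j] * p[j] for j in range(3)) + translation[i]
            for i in range(3)
        )
        for p in coords
    ]


def max_score_gap(score_fn, coords, trials=8, seed=0):
    """Largest |score(coords) - score(moved coords)| over random rigid motions."""
    rng = random.Random(seed)
    base = score_fn(coords)
    worst = 0.0
    for _ in range(trials):
        raw = [rng.gauss(0.0, 1.0) for _ in range(3)]
        norm = math.sqrt(sum(v * v for v in raw))
        axis = [v / norm for v in raw]
        rotation = rotation_matrix(axis, rng.uniform(0.0, 2.0 * math.pi))
        translation = [rng.uniform(-5.0, 5.0) for _ in range(3)]
        moved = apply_rigid(coords, rotation, translation)
        worst = max(worst, abs(score_fn(moved) - base))
    return worst


# Hypothetical stand-in scores used only to exercise the check.
def pairwise_score(coords):
    # Depends only on interatomic distances, so rigid motions leave it unchanged.
    return -sum(
        math.dist(a, b) ** 2 for i, a in enumerate(coords) for b in coords[i + 1:]
    )


def absolute_score(coords):
    # Depends on absolute positions, so rigid motions change it.
    return -sum(v * v for p in coords for v in p)


coords = [(0.0, 0.0, 0.0), (1.1, 0.0, 0.0), (1.5, 1.2, 0.3)]
invariant_gap = max_score_gap(pairwise_score, coords)
variant_gap = max_score_gap(absolute_score, coords)
```

Here `invariant_gap` stays at floating-point noise while `variant_gap` is large. To apply the same check to the model, `score_fn` would wrap a scoring call on array inputs; that wiring is omitted because the expected array shapes are not stated on this page.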
Public Surface
SE3MolecularFlow
Module-local molecular flow model for conformation scoring and sampling.
Retained behavior:
- log_prob(coordinates, atom_types, atom_mask) scores molecular coordinates under the current affine-coupling flow
- sample(atom_types, atom_mask, num_samples, *, rngs) draws molecular coordinate samples for the provided atom template
- generate(n_samples=1, *, rngs, atom_types=..., atom_mask=...) wraps the sampling path required by the shared generative-model interface
- loss_fn(batch, model_outputs, **kwargs) returns the canonical flow loss dictionary with total_loss, nll_loss, log_prob, and avg_log_prob
Helper classes in this module remain implementation details for the retained molecular flow runtime.
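To make the shape of the loss dictionary returned by loss_fn concrete, the sketch below assembles one from per-sample log-probabilities. It is an illustration under stated assumptions, not the module's implementation: `flow_loss_dict` is a hypothetical helper, and it assumes that nll_loss is the negative mean log-probability and that total_loss equals nll_loss with no extra terms. Only the four key names come from this page.

```python
def flow_loss_dict(log_probs):
    """Assemble a flow-style loss dictionary from per-sample log-probabilities.

    Assumption: total_loss is the negative mean log-likelihood with no
    additional regularization terms.
    """
    if not log_probs:
        raise ValueError("log_probs must contain at least one value")
    avg_log_prob = sum(log_probs) / len(log_probs)
    nll_loss = -avg_log_prob
    return {
        "total_loss": nll_loss,
        "nll_loss": nll_loss,
        "log_prob": list(log_probs),
        "avg_log_prob": avg_log_prob,
    }


# Two hypothetical per-molecule log-probabilities.
losses = flow_loss_dict([-12.0, -8.0])
```

Under these assumptions, minimizing total_loss is the same as maximizing the average log-probability of the batch.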
Example
from flax import nnx
import jax.numpy as jnp

from artifex.generative_models.models.flow.se3_molecular import SE3MolecularFlow

model = SE3MolecularFlow(
    hidden_dim=64,
    num_layers=3,
    num_coupling_layers=4,
    max_atoms=29,
    atom_types=5,
    rngs=nnx.Rngs(0),
)

# Atom template for a batch of two molecules with 29 atom slots each.
atom_types = jnp.ones((2, 29), dtype=jnp.int32)
atom_mask = jnp.ones((2, 29), dtype=jnp.bool_)

# Request two samples for the template above.
samples = model.generate(2, rngs=nnx.Rngs(1), atom_types=atom_types, atom_mask=atom_mask)