"""Structured feature extraction from PFASGroups halogen group detection.
This module provides :func:`extract_group_features`, a single public entry
point that runs the necessary :func:`~PFASGroups.parser.parse_mol` calls and
returns a :class:`GroupFeatureResult` with four named feature dictionaries
covering the following four group categories:
- **Polyhalogenated groups** (ids 35, 38, 45): aggregate match counts
- **Perhalogenated groups** (ids 34, 37, 44): per-halogen max component sizes
- **Alkyl chain (H pseudo-halogen)** aggregates: convenience view of the H
column of perhalogenated results
- **Generic functional groups** (ids 29-76): wildcard bag-of-groups counts
The six component groups (34/35/37/38/44/45) are excluded from wildcard
matching by design, so their entries in the generic-group counts are always zero.
Example
-------
>>> from rdkit import Chem
>>> from PFASGroups import extract_group_features
>>> mol = Chem.MolFromSmiles("FC(F)(F)C(F)(F)C(F)(F)C(=O)O") # PFBA, a short-chain PFOA analogue
>>> r = extract_group_features(mol)
>>> r.poly_counts
{'poly_alkyl': 1.0, 'poly_aryl': 0.0, 'poly_cyclic': 0.0}
>>> r.per_halogen_sizes['g34_F'] # longest perfluoroalkyl component (carbons)
4.0
>>> r.h_chain_sizes
{'alkyl_H': 0.0, 'aryl_H': 0.0, 'cyclic_H': 0.0}
>>> len(r.to_array()) # 3 + 15 + 48
66
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Union
import numpy as np
from rdkit import Chem
from .parser import parse_mol as _parse_mol # type: ignore[attr-defined]
from .getter import get_compiled_HalogenGroups as _get_compiled_HalogenGroups
# ---------------------------------------------------------------------------
# Public constants
# ---------------------------------------------------------------------------
#: Group IDs for perhalogenated groups (alkyl, aryl, cyclic).
#: These groups require *all* halogen-bearing carbons to carry only one halogen
#: type. Each match is attributed to a single halogen, or to the ``'H'``
#: pseudo-halogen for un-substituted components (see :data:`HALOGENS_ORDER`).
PER_GROUP_IDS: list[int] = [34, 37, 44]
#: Group IDs for polyhalogenated groups (alkyl, aryl, cyclic).
#: These groups allow a *mixture* of halogens and/or partial halogenation.
#: Match counts are aggregated across all halogens.
POLY_GROUP_IDS: list[int] = [35, 38, 45]
#: Halogen types recognised in :attr:`GroupFeatureResult.per_halogen_sizes`,
#: in per-halogen feature column order.
#: ``'H'`` is the pseudo-halogen used to detect un-substituted components of
#: every perhalogenated group (alkyl 34, aryl 37, cyclic 44) via
#: :class:`~PFASGroups.ComponentsSolverModel.ComponentsSolver`.
HALOGENS_ORDER: list[str] = ["F", "Cl", "Br", "I", "H"]
#: Ordered list of group IDs covered by :attr:`GroupFeatureResult.generic_groups`.
#: Ids {34, 35, 37, 38, 44, 45} are excluded from wildcard matching and will
#: always be zero in this vector.
GENERIC_GROUP_VOCAB: list[int] = list(range(29, 77))  # 48 entries: ids 29..76 inclusive
#: Human-readable names for every group id in :data:`GENERIC_GROUP_VOCAB`.
#: Built from an ordered list so that entry *i* maps to group id ``29 + i``;
#: the trailing comment on each line gives that entry's id.
# NOTE(review): "peroxydes" (ids 40 and 53) looks like a misspelling of
# "peroxides"; kept verbatim in case callers match on these exact strings.
GENERIC_GROUP_NAMES: dict[int, str] = dict(
    enumerate(
        [
            "acrylate",  # 29
            "acyl halide",  # 30
            "alcohol",  # 31
            "aldehyde",  # 32
            "alkene",  # 33
            "perhalogenated alkyl",  # 34
            "polyhalogenated alkyl",  # 35
            "alkyne",  # 36
            "perhalogenated aryl compounds",  # 37
            "polyhalogenated aryl compounds",  # 38
            "benzodioxole",  # 39
            "benzoyl peroxydes",  # 40
            "bromide",  # 41
            "carboxylic acid",  # 42
            "chloride",  # 43
            "perhalogenated cyclic compounds",  # 44
            "polyhalogenated cyclic compounds",  # 45
            "ester",  # 46
            "ether",  # 47
            "fluoride",  # 48
            "glucuronate",  # 49
            "iodide",  # 50
            "ketone",  # 51
            "methacrylate",  # 52
            "peroxydes",  # 53
            "side-chain aromatics",  # 54
            "sulfenic acid",  # 55
            "sulfenyl halide",  # 56
            "sulfinic acid",  # 57
            "sulfinyl amido sulfonic acid",  # 58
            "sulfonamide",  # 59
            "sulfonamidoethanol",  # 60
            "sulfonic acid",  # 61
            "sulfonyl halide",  # 62
            "sulfonyl propanoic acid",  # 63
            "sulfuric acid",  # 64
            "thioester keto dicarboxylic acid",  # 65
            "thiocyanic acid",  # 66
            "phosphinic acid",  # 67
            "phosphonic acid",  # 68
            "amide",  # 69
            "amine",  # 70
            "heterocyclic azine",  # 71
            "heterocyclic azole",  # 72
            "betaine",  # 73
            "glycine",  # 74
            "trichlorosilane",  # 75
            "silane",  # 76
        ],
        start=29,
    )
)
# ---------------------------------------------------------------------------
# Internal state
# ---------------------------------------------------------------------------
# Cache for _get_component_groups(); None until the first call fills it with
# the HalogenGroup objects whose ids are in {34, 35, 37, 38, 44, 45}.
_COMPONENT_GROUPS: list | None = None
# Position of each generic group id within GENERIC_GROUP_VOCAB.
_GENERIC_GROUP_INDEX: dict[int, int] = {gid: i for i, gid in enumerate(GENERIC_GROUP_VOCAB)}
# Column of each halogen (including the 'H' pseudo-halogen) within HALOGENS_ORDER.
_HAL_TO_COL: dict[str, int] = {h: i for i, h in enumerate(HALOGENS_ORDER)}
# Row/index of each per- and polyhalogenated group id. Derived from the public
# id lists instead of hard-coded so the mappings cannot drift from them;
# with the current lists these are {34: 0, 37: 1, 44: 2} and {35: 0, 38: 1, 45: 2}.
_PER_GID_TO_ROW: dict[int, int] = {gid: i for i, gid in enumerate(PER_GROUP_IDS)}
_POLY_GID_TO_IDX: dict[int, int] = {gid: i for i, gid in enumerate(POLY_GROUP_IDS)}
# Short labels used in h_chain_sizes, one per perhalogenated group in
# PER_GROUP_IDS order (alkyl, aryl, cyclic).
_H_CHAIN_KEYS: list[str] = ["alkyl_H", "aryl_H", "cyclic_H"]
# Corresponding per_halogen_sizes keys (H column of each per group);
# currently ["g34_H", "g37_H", "g44_H"].
_H_CHAIN_PHS_KEYS: list[str] = [f"g{gid}_H" for gid in PER_GROUP_IDS]
def _get_component_groups() -> list:
    """Return the cached list of component :class:`~PFASGroups.HalogenGroup` objects.

    The component groups are the perhalogenated and polyhalogenated groups,
    i.e. those whose ids appear in :data:`PER_GROUP_IDS` or
    :data:`POLY_GROUP_IDS` (currently ``{34, 35, 37, 38, 44, 45}``).

    On the first call the compiled groups are filtered and stored in the
    module global ``_COMPONENT_GROUPS``; later calls return that same list.
    """
    global _COMPONENT_GROUPS
    if _COMPONENT_GROUPS is None:
        # Derive the id set from the public constants so the two cannot drift.
        wanted_ids = set(PER_GROUP_IDS) | set(POLY_GROUP_IDS)
        _COMPONENT_GROUPS = [
            group for group in _get_compiled_HalogenGroups() if group.id in wanted_ids
        ]
    return _COMPONENT_GROUPS
# ---------------------------------------------------------------------------
# GroupFeatureResult
# ---------------------------------------------------------------------------
@dataclass
class GroupFeatureResult:
    """Structured feature extraction result for a single molecule.

    This dataclass groups features from two :func:`~PFASGroups.parser.parse_mol`
    calls into four semantically distinct dictionaries. It is returned by
    :func:`extract_group_features`.

    Attributes
    ----------
    poly_counts : dict[str, float]
        Match counts for **polyhalogenated** groups (ids 35, 38, 45), aggregated
        across all halogens (F, Cl, Br, I, H). Keys: ``'poly_alkyl'``,
        ``'poly_aryl'``, ``'poly_cyclic'``.
    per_halogen_sizes : dict[str, float]
        Maximum carbon-component size for **perhalogenated** groups (ids 34, 37,
        44) resolved per halogen. Keys follow the pattern ``'g{id}_{hal}'``
        where *id* is in ``{34, 37, 44}`` and *hal* is in
        ``['F', 'Cl', 'Br', 'I', 'H']``. Zero when no match is found.
    h_chain_sizes : dict[str, float]
        Convenience view of the ``'H'`` column of :attr:`per_halogen_sizes`:
        the largest un-substituted alkyl / aryl / cyclic component detected via
        the H pseudo-halogen mechanism. Keys: ``'alkyl_H'``, ``'aryl_H'``,
        ``'cyclic_H'``. **Not included in** :meth:`to_array`.
    generic_groups : dict[str, float]
        Wildcard functional-group match counts for group ids 29-76. Keys:
        ``'g29'`` … ``'g76'``. Ids ``{34, 35, 37, 38, 44, 45}`` are excluded
        from wildcard matching and will always be zero.
    """

    poly_counts: dict[str, float] = field(default_factory=dict)
    per_halogen_sizes: dict[str, float] = field(default_factory=dict)
    h_chain_sizes: dict[str, float] = field(default_factory=dict)
    generic_groups: dict[str, float] = field(default_factory=dict)

    @staticmethod
    def _feature_layout() -> tuple[list[str], list[str], list[str]]:
        """Return the (poly, per-halogen, generic) key lists in array order.

        Shared by :meth:`to_array` and :meth:`feature_names` so that the value
        order and the label order can never diverge.
        """
        poly_keys = ["poly_alkyl", "poly_aryl", "poly_cyclic"]
        per_keys = [f"g{gid}_{h}" for gid in PER_GROUP_IDS for h in HALOGENS_ORDER]
        gen_keys = [f"g{gid}" for gid in GENERIC_GROUP_VOCAB]
        return poly_keys, per_keys, gen_keys

    def to_array(self) -> np.ndarray:
        """Return a fixed-length float32 array of shape ``(66,)``.

        Layout::

            [0:3]   poly_counts        (poly_alkyl, poly_aryl, poly_cyclic)
            [3:18]  per_halogen_sizes  (g34_F…g44_H, 3 groups × 5 halogens)
            [18:66] generic_groups     (g29…g76, 48 entries)

        Keys missing from a dictionary (for example on a default-constructed
        instance, whose dictionaries are empty) contribute ``0.0``, so the
        array always has the full fixed length.

        Note
        ----
        :attr:`h_chain_sizes` is excluded from this array (it is a strict
        subset of :attr:`per_halogen_sizes`).
        """
        poly_keys, per_keys, gen_keys = self._feature_layout()
        values = (
            [self.poly_counts.get(k, 0.0) for k in poly_keys]
            + [self.per_halogen_sizes.get(k, 0.0) for k in per_keys]
            + [self.generic_groups.get(k, 0.0) for k in gen_keys]
        )
        return np.asarray(values, dtype=np.float32)

    def feature_names(self) -> list[str]:
        """Return the 66 feature names in the same order as :meth:`to_array`.

        Returns
        -------
        list[str]
            Labels: ``['poly_alkyl', 'poly_aryl', 'poly_cyclic',
            'g34_F', 'g34_Cl', …, 'g44_H', 'g29', 'g30', …, 'g76']``.
        """
        poly_keys, per_keys, gen_keys = self._feature_layout()
        return poly_keys + per_keys + gen_keys

    def __repr__(self) -> str:
        """Summarise the result as the number of nonzero entries per dictionary."""

        def _nonzero(values: dict[str, float]) -> int:
            return sum(1 for v in values.values() if v)

        return (
            f"{type(self).__name__}("
            f"poly_counts={_nonzero(self.poly_counts)} nonzero, "
            f"per_halogen_sizes={_nonzero(self.per_halogen_sizes)} nonzero, "
            f"generic_groups={_nonzero(self.generic_groups)} nonzero)"
        )
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------