# Source code for neo4j_graphrag.experimental.components.resolver

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

import abc
import logging
from itertools import combinations
from typing import Any, List, Optional, TYPE_CHECKING


try:
    from rapidfuzz import fuzz
    from rapidfuzz import utils

    IS_RAPIDFUZZ_INSTALLED = True
except ImportError:
    IS_RAPIDFUZZ_INSTALLED = False

try:
    import spacy
    from spacy.cli.download import download as spacy_download
    from spacy.language import Language
    import numpy as np

    IS_SPACY_INSTALLED = True
except ImportError:
    IS_SPACY_INSTALLED = False


if TYPE_CHECKING:
    import numpy as np
    from numpy.typing import NDArray

import neo4j
from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.experimental.pipeline.component import ComponentMeta
from neo4j_graphrag.utils import driver_config

logger = logging.getLogger(__name__)


class EntityResolver(Component):
    """Base component for entity resolution.

    Holds the database driver and the optional entity filter shared by all
    resolvers. The ``run`` method here only raises ``NotImplementedError``;
    subclasses provide the actual resolution logic.

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        filter_query (Optional[str]): Cypher query to select the entities to resolve. By default, all nodes with __Entity__ label are used
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        filter_query: Optional[str] = None,
    ) -> None:
        self.filter_query = filter_query
        # Keep the driver returned by the user-agent override helper rather
        # than the raw argument.
        self.driver = driver_config.override_user_agent(driver)

    async def run(self, *args: Any, **kwargs: Any) -> ResolutionStats:
        """Resolve entities; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()


class SinglePropertyExactMatchResolver(EntityResolver):
    """Resolve entities with same label and exact same property (default is "name").

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        filter_query (Optional[str]): To reduce the resolution scope, add a Cypher WHERE clause.
        resolve_property (str): The property that will be compared (default: "name"). If values match exactly, entities are merged.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).

    Example:

    .. code-block:: python

        from neo4j import GraphDatabase
        from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver

        URI = "neo4j://localhost:7687"
        AUTH = ("neo4j", "password")
        DATABASE = "neo4j"

        driver = GraphDatabase.driver(URI, auth=AUTH)
        resolver = SinglePropertyExactMatchResolver(driver=driver, neo4j_database=DATABASE)
        await resolver.run()  # no expected parameters

    """

    def __init__(
        self,
        driver: neo4j.Driver,
        filter_query: Optional[str] = None,
        resolve_property: str = "name",
        neo4j_database: Optional[str] = None,
    ) -> None:
        super().__init__(driver, filter_query)
        self.resolve_property = resolve_property
        self.neo4j_database = neo4j_database

    async def run(self) -> ResolutionStats:
        """Resolve entities based on the following rule:
        For each entity label, entities with the same 'resolve_property' value
        (exact match) are grouped into a single node:

        - Properties: the property from the first node will remain if already set, otherwise the first property in list will be written.
        - Relationships: merge relationships with same type and target node.

        See apoc.refactor.mergeNodes documentation for more details.

        Returns:
            ResolutionStats: the number of entities matched by the scope query
            and the number of nodes returned by the merge procedure. When no
            entity matches, only ``number_of_nodes_to_resolve=0`` is set and
            no merge query is executed.
        """
        match_query = "MATCH (entity:__Entity__) "
        if self.filter_query:
            match_query += self.filter_query

        # First count the entities in scope so the merge can be skipped
        # entirely when there is nothing to resolve.
        stat_query = f"{match_query} RETURN count(entity) as c"
        records, _, _ = self.driver.execute_query(
            stat_query,
            database_=self.neo4j_database,
        )
        number_of_nodes_to_resolve = records[0].get("c")
        if number_of_nodes_to_resolve == 0:
            return ResolutionStats(
                number_of_nodes_to_resolve=0,
            )

        merge_nodes_query = (
            f"{match_query} "
            # The property name is passed as a query parameter and read with
            # dynamic property access instead of being spliced into the query
            # text: names that would need quoting (spaces, dashes, ...) work
            # and the value cannot alter the structure of the query.
            "WITH entity, entity[$resolve_property] as prop "
            # keep only entities for which the resolve_property (name) is not null
            "WITH entity, prop WHERE prop IS NOT NULL "
            # will check the property for each of the entity labels,
            # except the reserved ones __Entity__ and __KGBuilder__
            "UNWIND labels(entity) as lab  "
            "WITH lab, prop, entity WHERE NOT lab IN ['__Entity__', '__KGBuilder__'] "
            # aggregate based on property value and label
            # collect all entities with exact same property and label
            # in the 'entities' list
            "WITH prop, lab, collect(entity) AS entities "
            # merge all entities into a single node
            # * merge relationships: if the merged entities have a relationship of same
            # type to the same target node, these relationships are merged
            # otherwise relationships are just attached to the newly created node
            # * properties: if the two entities have the same property key with
            # different values, only one of them is kept in the created node
            "CALL apoc.refactor.mergeNodes(entities,{ "
            " properties:'discard', "
            " mergeRels:true "
            "}) "
            "YIELD node "
            "RETURN count(node) as c "
        )
        records, _, _ = self.driver.execute_query(
            merge_nodes_query,
            {"resolve_property": self.resolve_property},
            database_=self.neo4j_database,
        )
        number_of_created_nodes = records[0].get("c")
        return ResolutionStats(
            number_of_nodes_to_resolve=number_of_nodes_to_resolve,
            number_of_created_nodes=number_of_created_nodes,
        )
class CombinedMeta(ComponentMeta, abc.ABCMeta):
    """
    A metaclass that merges ComponentMeta (from Component) and ABCMeta (from abc.ABC).
    """

    pass


class BasePropertySimilarityResolver(EntityResolver, abc.ABC, metaclass=CombinedMeta):
    """
    Base class for similarity-based matching of properties for entity resolution.
    Resolve entities with same label and similar set of textual properties (default is
    ["name"]):

    - Group entities by label
    - Concatenate the specified textual properties
    - Compute similarity between each pair
    - Consolidate overlapping sets
    - Merge similar nodes via APOC (See apoc.refactor.mergeNodes documentation for more
      details).

    Subclasses implement `compute_similarity` based on different strategies, and return
    a similarity score between 0 and 1.

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        filter_query (Optional[str]): Optional Cypher WHERE clause to reduce the resolution scope.
        resolve_properties (Optional[List[str]]): The list of properties to consider for similarity. Defaults to ["name"].
        similarity_threshold (float): The similarity threshold above which nodes are merged. Defaults to 0.8.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        filter_query: Optional[str] = None,
        resolve_properties: Optional[List[str]] = None,
        similarity_threshold: float = 0.8,
        neo4j_database: Optional[str] = None,
    ) -> None:
        super().__init__(driver, filter_query)
        self.resolve_properties = resolve_properties or ["name"]
        self.similarity_threshold = similarity_threshold
        self.neo4j_database = neo4j_database

    @abc.abstractmethod
    def compute_similarity(self, text_a: str, text_b: str) -> float:
        """
        Compute similarity between two textual strings.
        """
        pass

    async def run(self) -> ResolutionStats:
        """Merge entities whose concatenated properties are similar enough.

        For every non-reserved label, the configured properties of each entity
        are joined into one text, every pair of texts is scored with
        ``compute_similarity``, pairs scoring at or above
        ``similarity_threshold`` are consolidated into disjoint groups, and
        each group is merged with ``apoc.refactor.mergeNodes``.

        Returns:
            ResolutionStats: ``number_of_nodes_to_resolve`` is the number of
            (label, entity) entries that had a non-empty text;
            ``number_of_created_nodes`` is the number of rows returned by the
            merge queries.
        """
        match_query = "MATCH (entity:__Entity__)"
        if self.filter_query:
            match_query += f" {self.filter_query}"

        # generate a dynamic map of requested properties, e.g. "name: entity.name, description: entity.description, ..."
        props_map_list = [f"{prop}: entity.{prop}" for prop in self.resolve_properties]
        props_map = ", ".join(props_map_list)

        # Cypher query:
        # matches extracted entities
        # filters entities if filter_query is provided
        # unwinds labels to skip reserved ones
        # collects all properties needed for the calculation of similarity
        query = f"""
            {match_query}
            UNWIND labels(entity) AS lab
            WITH lab, entity
            WHERE NOT lab IN ['__Entity__', '__KGBuilder__']
            WITH lab, collect({{ id: elementId(entity), {props_map} }}) AS labelCluster
            RETURN lab, labelCluster
        """
        records, _, _ = self.driver.execute_query(query, database_=self.neo4j_database)

        total_entities = 0
        total_merged_nodes = 0

        # for each row, 'lab' is the label, 'labelCluster' is a list of dicts (id + textual properties)
        for row in records:
            entities = row["labelCluster"]

            node_texts = {}
            for ent in entities:
                # concatenate all textual properties (if non-null) into a single string
                texts = [
                    str(ent[p]) for p in self.resolve_properties if p in ent and ent[p]
                ]
                combined_text = " ".join(texts).strip()
                if combined_text:
                    node_texts[ent["id"]] = combined_text
            total_entities += len(node_texts)

            # compute pairwise similarity and mark those above the threshold
            pairs_to_merge = []
            for (id1, text1), (id2, text2) in combinations(node_texts.items(), 2):
                sim = self.compute_similarity(text1, text2)
                if sim >= self.similarity_threshold:
                    pairs_to_merge.append({id1, id2})

            # consolidate overlapping pairs into unique, pairwise-disjoint merge sets.
            merged_sets = self._consolidate_sets(pairs_to_merge)

            # perform merges in the db using APOC.
            merged_count = 0
            for node_id_set in merged_sets:
                if len(node_id_set) > 1:
                    merge_query = (
                        "MATCH (n) WHERE elementId(n) IN $ids "
                        "WITH collect(n) AS nodes "
                        "CALL apoc.refactor.mergeNodes(nodes, {properties: 'discard', mergeRels: true}) "
                        "YIELD node RETURN elementId(node)"
                    )
                    result, _, _ = self.driver.execute_query(
                        merge_query,
                        {"ids": list(node_id_set)},
                        database_=self.neo4j_database,
                    )
                    merged_count += len(result)
            total_merged_nodes += merged_count

        return ResolutionStats(
            number_of_nodes_to_resolve=total_entities,
            number_of_created_nodes=total_merged_nodes,
        )

    @staticmethod
    def _consolidate_sets(pairs: List[set[str]]) -> List[set[str]]:
        """Consolidate overlapping sets of node pairs into unique sets.

        The returned groups are pairwise disjoint: when a pair bridges several
        existing groups (e.g. ``{a, b}``, ``{c, d}`` followed by ``{b, c}``),
        all of them are fused into a single group. Stopping at the first
        intersecting group would leave ``{a, b, c}`` and ``{c, d}`` sharing
        ``c``, and the second database merge would then target a node that the
        first merge already consumed.

        Args:
            pairs: Sets of node ids to be merged together. The input sets are
                not modified.

        Returns:
            The disjoint groups; a fused group keeps the position of the first
            group it absorbed.
        """
        consolidated: List[set[str]] = []
        for pair in pairs:
            target: Optional[set[str]] = None
            remaining: List[set[str]] = []
            for group in consolidated:
                # Existing groups are disjoint from each other, so testing
                # against the incoming pair alone finds every group to fuse.
                if group.isdisjoint(pair):
                    remaining.append(group)
                elif target is None:
                    target = group
                    target.update(pair)
                    remaining.append(target)
                else:
                    # A second (or later) group touched by the same pair:
                    # absorb it into the first one instead of keeping both.
                    target.update(group)
            if target is None:
                remaining.append(set(pair))
            consolidated = remaining
        return consolidated
class SpaCySemanticMatchResolver(BasePropertySimilarityResolver):
    """
    Resolve entities with same label and similar set of textual properties (default is
    ["name"]) based on spaCy's static embeddings and cosine similarities.

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        filter_query (Optional[str]): Optional Cypher WHERE clause to reduce the resolution scope.
        resolve_properties (Optional[List[str]]): The list of properties to consider for embeddings Defaults to ["name"].
        similarity_threshold (float): The similarity threshold above which nodes are merged. Defaults to 0.8.
        spacy_model (str): The name of the spaCy model to load. Defaults to "en_core_web_lg".
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).

    Example:

    .. code-block:: python

        from neo4j import GraphDatabase
        from neo4j_graphrag.experimental.components.resolver import SpaCySemanticMatchResolver

        URI = "neo4j://localhost:7687"
        AUTH = ("neo4j", "password")
        DATABASE = "neo4j"

        driver = GraphDatabase.driver(URI, auth=AUTH)
        resolver = SpaCySemanticMatchResolver(driver=driver, neo4j_database=DATABASE)
        await resolver.run()  # no expected parameters

    """

    def __init__(
        self,
        driver: neo4j.Driver,
        filter_query: Optional[str] = None,
        resolve_properties: Optional[List[str]] = None,
        similarity_threshold: float = 0.8,
        spacy_model: str = "en_core_web_lg",
        neo4j_database: Optional[str] = None,
    ) -> None:
        # Fail early, before touching the driver, when the optional spaCy
        # dependency could not be imported at module load time.
        if not IS_SPACY_INSTALLED:
            raise ImportError(
                "`spacy` python module needs to be installed to use the SpaCySemanticMatchResolver. \n"
                'Install it with: `pip install "neo4j-graphrag[nlp]"` '
            )
        super().__init__(
            driver,
            filter_query,
            resolve_properties,
            similarity_threshold,
            neo4j_database,
        )
        self.nlp = self._load_or_download_spacy_model(spacy_model)
        # text -> embedding vector, filled lazily by _get_embedding
        self.embedding_cache: dict[str, NDArray[np.float64]] = {}

    async def run(self) -> ResolutionStats:
        """Run the resolution implemented by the parent class."""
        return await super().run()

    def compute_similarity(self, text_a: str, text_b: str) -> float:
        """Return the cosine similarity between the embeddings of both texts."""
        vector_a = np.asarray(self._get_embedding(text_a), dtype=np.float64)
        vector_b = np.asarray(self._get_embedding(text_b), dtype=np.float64)
        return self._cosine_similarity(vector_a, vector_b)

    def _get_embedding(self, text: str) -> NDArray[np.float64]:
        """Return the document vector for ``text``, computing it at most once."""
        if text in self.embedding_cache:
            return self.embedding_cache[text]
        vector = np.asarray(self.nlp(text).vector, dtype=np.float64)
        self.embedding_cache[text] = vector
        return vector

    @staticmethod
    def _cosine_similarity(
        vec1: NDArray[np.float64], vec2: NDArray[np.float64]
    ) -> float:
        """Calculate cosine similarity between two embedding vectors.

        Returns 0.0 when either vector has a zero norm, which avoids a
        division by zero.
        """
        length_1 = np.linalg.norm(vec1)
        length_2 = np.linalg.norm(vec2)
        numerator = np.dot(vec1, vec2)
        if not (length_1 and length_2):
            return 0.0
        return float(numerator / (length_1 * length_2))

    @staticmethod
    def _load_or_download_spacy_model(model_name: str) -> Language:
        """
        Attempt to load the specified spaCy model by name.
        If not installed, automatically download and then load it.

        Any ``OSError`` whose message is not the "model not found" one is
        re-raised unchanged.
        """
        try:
            return spacy.load(model_name)
        except OSError as e:
            missing_model = (
                "doesn't seem to be a Python package or a valid path" in str(e)
            )
            if not missing_model:
                raise e
            # handling cases where the spaCy model is not yet downloaded:
            logger.info(f"Model '{model_name}' not found. Downloading...")
            spacy_download(model_name)
            return spacy.load(model_name)
class FuzzyMatchResolver(BasePropertySimilarityResolver):
    """
    Resolve entities with the same label and similar set of textual properties using
    RapidFuzz for fuzzy matching. Similarity scores are normalized to a value between 0
    and 1.

    Args:
        driver (neo4j.Driver): The Neo4j driver to connect to the database.
        filter_query (Optional[str]): Optional Cypher WHERE clause to reduce the resolution scope.
        resolve_properties (Optional[List[str]]): The list of properties to consider for fuzzy matching. Defaults to ["name"].
        similarity_threshold (float): The similarity threshold above which nodes are merged. Defaults to 0.8.
        neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, the driver's default is used.

    Raises:
        ImportError: If the optional `rapidfuzz` dependency is not installed.
    """

    def __init__(
        self,
        driver: neo4j.Driver,
        filter_query: Optional[str] = None,
        resolve_properties: Optional[List[str]] = None,
        similarity_threshold: float = 0.8,
        neo4j_database: Optional[str] = None,
    ) -> None:
        if not IS_RAPIDFUZZ_INSTALLED:
            # The message names this class; it previously pointed at
            # SpaCySemanticMatchResolver, which sent users to the wrong extra.
            raise ImportError(
                "`rapidfuzz` python module needs to be installed to use the FuzzyMatchResolver. "
                'Install it with: `pip install "neo4j-graphrag[fuzzy-matching]"` '
            )
        super().__init__(
            driver,
            filter_query,
            resolve_properties,
            similarity_threshold,
            neo4j_database,
        )

    async def run(self) -> ResolutionStats:
        """Run the resolution implemented by the parent class."""
        return await super().run()

    def compute_similarity(self, text_a: str, text_b: str) -> float:
        """Return the RapidFuzz ``WRatio`` score of both texts, scaled to 0..1."""
        # RapidFuzz's fuzz.WRatio returns a score from 0 to 100
        # normalize the input strings before the comparison is done (processor=utils.default_process)
        # e.g., lowercase the text, strip whitespace, and remove punctuation
        # normalize the score to the 0..1 range
        return fuzz.WRatio(text_a, text_b, processor=utils.default_process) / 100.0