Back to Article
3. Classification code
Download Notebook

3. Classification code

In [1]:
import os
import ast
import math
import json
import time
import pandas as pd
from dotenv import load_dotenv
import prompt_templates_classification as ptc
from langchain.schema import BaseOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory
)
from json.decoder import JSONDecodeError
from google.generativeai.types import BlockedPromptException
from google.generativeai.types.generation_types import StopCandidateException
In [2]:
# Country whose translated news articles are classified in this notebook.
country = "Italy"
# Root folder that CONTAINS the "EU-S Data" directory. Every path below is built as
# f"{path2SP}/EU-S Data/...", so this value must not itself end in "EU-S Data/".
path2SP = "..."

Loading the Google AI API key (from a `.env` file or the environment)

In [3]:
# Read the Google AI key from a local .env file (or the process environment) and expose it as
# GOOGLE_API_KEY, presumably the variable ChatGoogleGenerativeAI reads, since it is created
# below without an explicit key.
load_dotenv()
GoogleAI_key = os.getenv("googleAI_API_key") or os.getenv("GOOGLE_API_KEY")
if not GoogleAI_key:
    # Fail early with a clear message instead of a TypeError from os.environ[...] = None.
    raise RuntimeError(
        "Google AI API key not found: define googleAI_API_key (or GOOGLE_API_KEY) "
        "in the .env file or in the environment."
    )
os.environ['GOOGLE_API_KEY'] = GoogleAI_key

Loading data

In [4]:
# Translated articles for `country` (output of the data-extraction-1 stage, "ready4class" folder).
translated_articles_path = f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-extraction-1/ready4class/{country}_translated.parquet.gzip"
country_data = pd.read_parquet(translated_articles_path)
country_data.head(5)

Defining Chain

In [7]:
# Every listed harm category gets the BLOCK_NONE threshold, so Gemini's safety filters do not
# block these requests. This is presumably needed because news on crime, corruption or violence
# would otherwise trip the filters.
safety_settings = {
    harm_category: HarmBlockThreshold.BLOCK_NONE
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
}
In [7]:
class JSONOutputParser(BaseOutputParser):
    """LangChain output parser that decodes the model's text reply as JSON."""

    def parse(self, text: str):
        """
        Remove Markdown code-fence markers ("```json" and "```") from `text` and decode the rest
        as JSON. `strict=False` allows control characters (e.g. raw newlines) inside JSON strings.
        """
        unfenced = text.replace('```json', '')
        unfenced = unfenced.replace('```', '')
        return json.loads(unfenced, strict=False)
In [8]:
# Keywords that mark a Stage 1 label as Rule-of-Law related (used to pre-filter Stage 2).
_ROL_KEYWORDS = ("Related", "Justice", "Governance")


def _is_rol_related(label):
    """Return True if the Stage 1 `label` is a string containing any of `_ROL_KEYWORDS`."""
    return isinstance(label, str) and any(keyword in label for keyword in _ROL_KEYWORDS)


def _normalize_rol_label(label):
    """Map Stage 1 labels mentioning "Governance" or "Justice" to "Related"; return others unchanged."""
    if isinstance(label, str) and ("Governance" in label or "Justice" in label):
        return "Related"
    return label


def _invoke_with_retry(chain, inputs, attempts = 2, pause = 1.0):
    """
    Call `chain.invoke(inputs)` up to `attempts` times, retrying only when the reply is not valid JSON.

    Returns the parsed reply, or None if the API blocked the prompt or every attempt produced
    undecodable JSON. Sleeps `pause` seconds after every call to throttle requests.
    """
    for attempt in range(1, attempts + 1):
        try:
            return chain.invoke(inputs)
        except JSONDecodeError:
            if attempt < attempts:
                print("Decode error... trying again...")
            else:
                print("Failed. Skipping article.")
        # The API can block some prompts for undisclosed reasons; nothing can be done but skip them.
        except (BlockedPromptException, StopCandidateException):
            print("Prompt BLOCKED")
            return None
        finally:
            # Throttle every call (blocked or failed ones count too). Presumably the Gemini API
            # allows ~60 requests per minute. TODO confirm.
            time.sleep(pause)
    return None


def classify_article(headline, summary, body, id = None, stage_1 = True, relation = None):
    """
    Classify a news article with Google's Gemini model.

    Two classification stages are supported:
      * Stage 1 (`stage_1=True`): is the article related or unrelated to the Rule of Law, and
        where did the events take place?
      * Stage 2 (`stage_1=False`): how closely is the article related to each of the eight
        pillars of the Rule of Law?

    Parameters
    ----------
    headline, summary, body : str
        Article text, passed to the prompt as `headline`, `summary` and `body`.
    id : optional
        Article identifier. Currently unused; kept for call-site compatibility.
    stage_1 : bool
        Selects the Stage 1 or the Stage 2 prompt.
    relation : optional
        Stage 1 label of the article (Stage 2 only). Articles whose label is not a string
        containing "Related", "Justice" or "Governance" are not sent to the model and get
        "Unrelated". When None, the pre-filter is skipped and the article is always classified.

    Returns
    -------
    Stage 1: `[rol_related, country]` from the model reply; labels mentioning "Governance" or
    "Justice" are normalized to "Related".
    Stage 2: JSON string of the reply's `pillars_relation` field, or "Unrelated" if pre-filtered.
    Both stages: "Skipped article" when the prompt was blocked, the reply could not be decoded
    as JSON after one retry, or the reply lacked the expected fields.
    """
    # Stage 2 is only worth a call for articles Stage 1 found related to the Rule of Law.
    if not stage_1 and relation is not None and not _is_rol_related(relation):
        return "Unrelated"

    # NOTE(review): the stage-specific system prompts (ptc.context_stage_1 / ptc.context_stage_2)
    # are currently NOT sent, only the human instructions are. Add ("system", <context>) to the
    # message list below to include them.
    human_prompt = ptc.instructions_stage_1 if stage_1 else ptc.instructions_stage_2
    chat_prompt  = ChatPromptTemplate.from_messages([("human", human_prompt)])

    chain_gemini = chat_prompt | ChatGoogleGenerativeAI(model = "gemini-pro",
                                                        temperature     = 0.1,
                                                        safety_settings = safety_settings,
                                                        convert_system_message_to_human = True) | JSONOutputParser()

    llm_response = _invoke_with_retry(chain_gemini, {
        "headline": headline,
        "summary" : summary,
        "body"    : body,
    })
    if llm_response is None:
        return "Skipped article"

    try:
        if stage_1:
            return [_normalize_rol_label(llm_response["rol_related"]), llm_response["country"]]
        return json.dumps(llm_response["pillars_relation"])
    # The model may ignore the requested schema (missing keys or a non-object reply).
    except (KeyError, TypeError):
        print("Unexpected response format. Skipping article.")
        return "Skipped article"
In [6]:
# Number of articles (rows) and fields (columns) to classify.
(len(country_data.index), len(country_data.columns))

Sending API calls in sets (1,000 articles each, saved to one file per set) and batches (100 articles each)

In [8]:
# Output folder for this country's per-set classification files.
classification_dir = f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-classification-1/{country}"
if os.path.isdir(classification_dir):
    print("Directory already exists")
else:
    # makedirs also creates missing parent folders (os.mkdir would raise FileNotFoundError).
    os.makedirs(classification_dir, exist_ok=True)
    print("Directory created")
In [9]:
# Articles are processed in "sets" of `set_size` rows; each set is written to its own parquet
# file as soon as it is finished. Inside a set, rows are classified in batches of `batch_size`.
set_size   = 1000
batch_size = 100


def _classify_batch(batch):
    """
    Add the Stage 1 (`topic_related`, `location_events`) and Stage 2 (`pillars_score`) columns
    to `batch` in place and return it. `batch` must be a copy, not a view of `country_data`.
    """
    print("====== STAGE 1 =====")
    stage_1_outcomes = [
        classify_article(
            row["title_trans"],
            row["description_trans"],
            row["content_trans"],
            row["id"],
            stage_1 = True
        )
        for _, row in batch.iterrows()
    ]
    # classify_article returns [relation, location] on success but a bare string
    # ("Skipped article") otherwise; pad the latter so both columns are always filled,
    # even when every article of the batch was skipped.
    stage_1_pairs = [
        outcome if isinstance(outcome, (list, tuple)) else [outcome, None]
        for outcome in stage_1_outcomes
    ]
    batch["topic_related"]   = [pair[0] for pair in stage_1_pairs]
    batch["location_events"] = [pair[1] for pair in stage_1_pairs]

    print("====== STAGE 2 =====")
    batch["pillars_score"] = [
        classify_article(
            row["title_trans"],
            row["description_trans"],
            row["content_trans"],
            row["id"],
            relation = row["topic_related"],
            stage_1  = False
        )
        for _, row in batch.iterrows()
    ]
    return batch


nsets = math.ceil(len(country_data) / set_size)
# `set_number` (not `set`) so the builtin `set` is not shadowed for the rest of the notebook.
for set_number in range(1, nsets + 1):

    print("=======================================")
    print(f"Starting with SET {set_number} out of {nsets}")
    print("=======================================")

    set_start = (set_number - 1) * set_size
    set_stop  = set_start + set_size
    results   = []

    # Batch starts never go past the last row, so every batch is non-empty.
    batch_starts = range(set_start, min(set_stop, len(country_data)), batch_size)
    for batch_number, starting_row in enumerate(batch_starts, start = 1):

        end_row = min(starting_row + batch_size, set_stop)
        # Copy only the slice (not the whole frame) so new columns can be added safely.
        batch_subset = country_data.iloc[starting_row:end_row].copy()

        print("============================================================================")
        print(f"Sending batch number: {batch_number}, start: {starting_row}, end: {end_row}")
        print("============================================================================")

        results.append(_classify_batch(batch_subset))

    # Collapsing and saving data
    collapsed_data = pd.concat(results).drop_duplicates(subset="id")
    collapsed_data.to_parquet(f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-classification-1/{country}/{country}_set_{set_number}.parquet.gzip", compression="gzip")
    # Short pause between sets, presumably extra headroom for the API rate limit.
    time.sleep(5)

Compiling sets

In [8]:
# Read every per-set file written by the classification loop and stack them into one frame.
classification_dir = f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-classification-1/{country}"
set_prefix = f"{country}_set_"
set_suffix = ".parquet.gzip"


def _set_number(file_name):
    """Return N for a "<country>_set_<N>.parquet.gzip" file name, or None if the name does not match."""
    if not (file_name.startswith(set_prefix) and file_name.endswith(set_suffix)):
        return None
    middle = file_name[len(set_prefix):-len(set_suffix)]
    return int(middle) if middle.isdigit() else None


# Keep only files that follow the naming pattern (sync tools can leave hidden files such as
# desktop.ini or .DS_Store in the folder), ordered by set number for a reproducible row order.
set_files = sorted(
    (name for name in os.listdir(classification_dir) if _set_number(name) is not None),
    key = _set_number,
)
if not set_files:
    raise FileNotFoundError(f"No '{set_prefix}*{set_suffix}' files found in {classification_dir}")

classified_data_list = [pd.read_parquet(f"{classification_dir}/{file}") for file in set_files]
classified_data = pd.concat(classified_data_list)

Converting pillar scores to binary indicators (1 if the pillar score is at least 7)

In [9]:
def extract_score(string, pillar, t = 7):
    """
    Return a binary indicator of whether an article's score for one pillar reaches a threshold.

    Parameters
    ----------
    string : str
        Stage 2 value stored in `pillars_score`: a Python/JSON literal holding either a list of
        dicts (pillar name -> score) or a single such dict, or one of the sentinel strings
        "Unrelated" / "Skipped article".
    pillar : int
        Pillar number; the score of the first key containing "<pillar>. " is used.
    t : int or float
        Threshold; scores greater than or equal to `t` give 1.

    Returns
    -------
    1 if the pillar's score is >= t; 0 if it is lower, the pillar is missing, or `string` is a
    sentinel; None if `string` cannot be parsed or the score is not numeric.
    """
    # Sentinels written by classify_article for articles that were never scored.
    if isinstance(string, str) and string in ("Unrelated", "Skipped article"):
        return 0

    try:
        parsed = ast.literal_eval(string)
    # Non-literal text raises SyntaxError/ValueError; non-strings (None, NaN, pd.NA) raise ValueError.
    except (ValueError, SyntaxError):
        return None

    if isinstance(parsed, dict):
        entries = list(parsed.items())
    elif isinstance(parsed, (list, tuple)) and all(isinstance(item, dict) for item in parsed):
        entries = [entry for item in parsed for entry in item.items()]
    else:
        return None

    pattern = f"{pillar}. "
    # A pillar the model did not report counts as a score of 0.
    score = next((value for key, value in entries if pattern in str(key)), 0)
    try:
        return 1 if float(score) >= t else 0
    except (TypeError, ValueError):
        return None
In [10]:
# One binary column per Rule of Law pillar (1-8), derived from the Stage 2 scores.
for pillar_number in range(1, 9):
    pillar_column = f"pillar_{pillar_number}"
    classified_data[pillar_column] = classified_data["pillars_score"].apply(
        lambda scores, pillar_number=pillar_number: extract_score(scores, pillar_number)
    )

Cleaning location of events and topic relation

In [11]:
def loc2bin(location, country):
    """
    Flag whether `country` appears in an article's `location` value.

    Returns False for missing locations (None/NaN). Otherwise returns the result of
    `country in location`, which is a substring check when `location` is a string.
    """
    has_location = not pd.isna(location)
    return has_location and country in location
In [12]:
# Names for which a binary location flag is built. Matching is by substring (via loc2bin), so
# "Czech" covers e.g. "Czech Republic". "Euro" is not a member state: it flags any location
# containing "Euro" (e.g. "Europe", "European Union").
eu_member_states = [
    "Austria", "Belgium", "Bulgaria", "Croatia", "Cyprus", "Czech", "Denmark",
    "Estonia", "Finland", "France", "Germany", "Greece", "Hungary", "Ireland",
    "Italy", "Latvia", "Lithuania", "Luxembourg", "Malta", "Netherlands", "Poland",
    "Portugal", "Romania", "Slovakia", "Slovenia", "Spain", "Sweden", "Euro",
]
for state in eu_member_states:
    classified_data[f"location_{state}"] = classified_data["location_events"].apply(
        lambda location, state=state: loc2bin(location, state)
    )

Saving data

In [13]:
# Compiled output folder; created if missing so the write below cannot fail on a fresh setup.
compiled_dir = f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-classification-1/0_compiled"
os.makedirs(compiled_dir, exist_ok=True)
classified_data.to_parquet(f"{compiled_dir}/{country}_classified.parquet.gzip", compression="gzip")