Back to Article
4. Article summarization code
Download Notebook

4. Article summarization code

In [2]:
import os
import ast
import math
import json
import time
import pandas as pd
from dotenv import load_dotenv
import prompt_templates_summarization as pts
from langchain.schema import BaseOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory
)
from json.decoder import JSONDecodeError
from google.generativeai.types import BlockedPromptException
from google.generativeai.types.generation_types import StopCandidateException
In [3]:
# Run configuration: the country whose articles are summarized, and the root of the shared data tree.
# NOTE(review): later cells build paths as f"{path2SP}/EU-S Data/...", so with this placeholder the
# resolved path contains "EU-S Data" twice — confirm the intended root folder.
country, path2SP = "Italy", ".../EU-S Data/"

Loading the API key

In [4]:
# Read variables from a local .env file into the process environment, then expose the key under the
# name LangChain's Google GenAI integration reads (GOOGLE_API_KEY).
load_dotenv()
GoogleAI_key = os.getenv("googleAI_API_key") or os.getenv("GOOGLE_API_KEY")
if not GoogleAI_key:
    # Assigning None to os.environ would raise an opaque TypeError; fail with an actionable message.
    raise RuntimeError(
        "Google AI API key not found: define 'googleAI_API_key' in your .env file "
        "or set the GOOGLE_API_KEY environment variable."
    )
os.environ['GOOGLE_API_KEY'] = GoogleAI_key

Loading data

In [5]:
def showEverything(df):
    """
    Display ``df`` without truncation.

    For the duration of the call only, the following pandas display options apply: unlimited rows
    and columns, a 1000-character line width, 3-digit precision and left-justified column headers.
    """
    untruncated_view = {
        'display.max_rows': None,
        'display.max_columns': None,
        'display.width': 1000,
        'display.precision': 3,
        'display.colheader_justify': 'left',
    }
    # option_context expects a flat "name, value, name, value, ..." argument list.
    flat_options = [token for option_pair in untruncated_view.items() for token in option_pair]
    with pd.option_context(*flat_options):
        display(df)
In [4]:
# Classified articles for the selected country.
country_data = pd.read_parquet(
    f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-classification-1/0_compiled/{country}_classified.parquet.gzip"
)

# The location flag column is named after the country, except for Czechia, whose column uses "Czech".
location_variable = "location_Czech" if country == "Czechia" else f"location_{country}"

# Keep articles located in the country whose topic label mentions Related/Justice/Governance.
# `== True` keeps the mask boolean even if the flag column is not bool dtype. na=False treats a
# missing topic label as "no match"; otherwise the NaN in the mask would make .loc raise.
subset_data = (
    country_data
    .loc[
        (country_data[location_variable] == True)
        & country_data["topic_related"].str.contains("Related|Justice|Governance", na=False)
    ]
    .copy()
)
subset_data.shape
In [5]:
# Column totals for every column whose name contains "pillar_".
# If these are boolean flags (presumably one per Rule of Law pillar — confirm), each total is the
# number of selected articles tagged with that pillar.
pillar_columns = subset_data.filter(like='pillar_')
pillar_sum = pillar_columns.sum(axis="index")
pillar_sum

Defining the LLM chain

In [8]:
# Safety configuration passed to Gemini: every harm category listed below is mapped to BLOCK_NONE,
# so the model's safety filter is set not to block content in any of these categories.
safety_settings = {
    harm_category: HarmBlockThreshold.BLOCK_NONE
    for harm_category in (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
}
In [9]:
class JSONOutputParser(BaseOutputParser):
    """LangChain output parser that decodes the model's text reply as JSON."""

    def parse(self, text: str):
        """
        Remove Markdown code-fence markers (```json and ```) from ``text`` and decode the rest as JSON.

        ``strict=False`` allows control characters such as raw newlines inside JSON strings.
        Raises ``json.JSONDecodeError`` when the remaining text is not valid JSON.
        """
        unfenced = text.replace('```json', '')
        unfenced = unfenced.replace('```', '')
        return json.loads(unfenced, strict=False)
In [19]:
def _build_gemini_chain():
    """Assemble the prompt -> Gemini -> JSON-parser chain used by ``summarize_article``."""
    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", pts.context),
        ("human", pts.instructions)
    ])
    gemini = ChatGoogleGenerativeAI(model = "gemini-pro",
                                    temperature = 0.1,
                                    safety_settings = safety_settings,
                                    convert_system_message_to_human = True)
    return chat_prompt | gemini | JSONOutputParser()


def summarize_article(headline, summary, body, pillar, max_attempts=2, wait_seconds=3):
    """
    Ask Google's Gemini to summarize a news article and score its impact on one Rule of Law pillar.

    Parameters
    ----------
    headline, summary, body : str
        Headline, short description and full text of the article.
    pillar : int or str
        Pillar number. It is converted with ``str`` and used as the key into ``pts.pillar_names``
        and ``pts.pillar_bullets``.
    max_attempts : int, default 2
        Total number of calls allowed when the reply is not valid JSON or lacks the expected keys.
    wait_seconds : float, default 3
        Pause after every API call, successful or not, to stay under the Gemini per-minute
        request quota.

    Returns
    -------
    list or str
        ``[summary, impact_score]`` taken from the model's JSON reply, or the string
        ``"Skipped article"`` when the prompt was blocked or no usable reply was obtained.
    """
    idx = str(pillar)
    chain_gemini = _build_gemini_chain()
    inputs = {
        "headline"       : headline,
        "summary"        : summary,
        "body"           : body,
        "pillar_name"    : pts.pillar_names[idx],
        "pillar_bullets" : pts.pillar_bullets[idx]
    }

    for attempt in range(1, max_attempts + 1):
        llm_response = None
        try:
            llm_response = chain_gemini.invoke(inputs)
        except (BlockedPromptException, StopCandidateException):
            # The API can block a prompt or stop a candidate for undisclosed reasons. There is
            # nothing we can do about it, so the article is skipped without retrying.
            print("Prompt BLOCKED")
            return "Skipped article"
        except JSONDecodeError:
            pass    # handled below, together with replies that lack the expected keys
        finally:
            # Runs on every path, including the early return above, so calls are always spaced out.
            time.sleep(wait_seconds)

        if isinstance(llm_response, dict) and "summary" in llm_response and "impact_score" in llm_response:
            return [llm_response["summary"], llm_response["impact_score"]]

        if attempt < max_attempts:
            print("Decode error... trying again...")
        else:
            print("Failed. Skipping article.")

    return "Skipped article"

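Sending calls in sets and batches: each pillar's articles are split into sets of 1,000 rows (one output file per set), and each set is sent to Gemini in batches of 100 articles.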

In [8]:
# Output folder for this country, with one sub-folder per pillar (pillar_1 ... pillar_8).
# makedirs(exist_ok=True) creates any missing parent folders. It also does not abort when some
# pillar folders already exist, so the cell is safe to re-run and fills in any missing folder.
summarization_dir = f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-summarization/{country}"
for p in range(1, 9):
    os.makedirs(f"{summarization_dir}/pillar_{p}", exist_ok=True)
print(f"Output directories ready under {summarization_dir}")
In [9]:
SET_SIZE   = 1000   # articles per saved parquet file ("set")
BATCH_SIZE = 100    # articles sent between progress messages within a set


def _as_summary_row(outcome):
    """
    Convert a ``summarize_article`` outcome into a Series indexed by ("summary", "impact_score").

    A successful call returns ``[summary, impact_score]``. A skipped article returns the string
    ``"Skipped article"``, which is kept as the summary with a missing score. Always returning both
    labels keeps the two-column assignment below valid even when every article in a batch is skipped.
    """
    if isinstance(outcome, str):
        return pd.Series({"summary": outcome, "impact_score": None})
    summary_text, impact_score = outcome
    return pd.Series({"summary": summary_text, "impact_score": impact_score})


for p in range(1, 9):

    print("=========================")
    print(f"Starting with PILLAR {p}")
    print("=========================")

    pillar_subset = subset_data.loc[subset_data[f"pillar_{p}"] == True]
    nsets = math.ceil(len(pillar_subset) / SET_SIZE)

    for set_number in range(1, nsets + 1):
        print("==============================================")
        print(f"Starting with SET {set_number} out of {nsets} set(s)")
        print("==============================================")

        results = []

        for batch_number in range(1, SET_SIZE // BATCH_SIZE + 1):
            starting_row = ((set_number - 1) * SET_SIZE) + ((batch_number - 1) * BATCH_SIZE)
            end_row      = starting_row + BATCH_SIZE
            # Copy the slice so the new columns go into an independent frame (no SettingWithCopyWarning).
            batch_subset = pillar_subset.iloc[starting_row:end_row].copy()

            if batch_subset.empty:
                break   # rows are consumed in order, so every later batch of this set is empty too

            print("============================================================================")
            print(f"Sending batch number: {batch_number}, start: {starting_row}, end: {end_row}")
            print("============================================================================")

            batch_subset[["summary", "impact_score"]] = batch_subset.apply(
                lambda row: _as_summary_row(summarize_article(
                    row["title_trans"],
                    row["description_trans"],
                    row["content_trans"],
                    p
                )),
                axis = 1
            )
            results.append(batch_subset)

        # Collapsing and saving data
        collapsed_data = pd.concat(results).drop_duplicates(subset="id")
        # Any non-numeric score ("N/A", the missing score of a skipped article, ...) becomes 0.
        collapsed_data["impact_score"] = (
            pd.to_numeric(collapsed_data["impact_score"], errors="coerce")
            .fillna(0)
            .astype(int)
        )
        collapsed_data.to_parquet(
            f"{path2SP}/EU-S Data/Automated Qualitative Checks/Data/data-summarization/{country}/pillar_{p}/{country}_pillar_{p}_set_{set_number}.parquet.gzip",
            compression="gzip"
        )
        time.sleep(5)