# ably.do/hft.py

import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, Gemma2ForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
from PIL import Image
import re
import pytesseract
import docx2txt
import PyPDF2
import json
from huggingface_hub import login

# Read the Hugging Face access token from the environment instead of hardcoding it
login(token=os.environ.get("HF_TOKEN"))
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def identify_legal_document(filename, file_catalog):
    return file_catalog.get(filename, "")

# Function for extracting text from different file types
def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                text += page.extract_text() or ""  # extract_text() may yield nothing for scanned pages
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""
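
# Note: OCR for images assumes the Tesseract binary is installed alongside
# pytesseract, and docx2txt only reads .docx archives, so legacy .doc files may
# not extract cleanly.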

# Data preparation
def prepare_dataset(directory, catalog_path):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if text:
                # Check whether the file is listed in the catalog
                doc_type = identify_legal_document(file, file_catalog)
                if doc_type != "Opracowanie własne":  # "Opracowanie własne" ("own work") labels non-statute documents
                    # Processing for legal acts: split the text into articles
                    articles = re.split(r'(Art\.\s+\d+\.)', text)[1:]
                    for i in range(0, len(articles), 2):
                        if i + 1 < len(articles):
                            article_number = articles[i].strip()
                            article_content = articles[i + 1].strip()
                            data.append({
                                "text": f"{article_number} {article_content}",
                                "source": f"{doc_type}, {article_number}"
                            })
                else:
                    # Processing for ordinary documents: plain 512-character chunks
                    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source": ""
                        })
    return data

# Data tokenization
def tokenize_function(examples):
    inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
    # Labels are a copy of the input ids (note: padding tokens are not masked out here)
    inputs["labels"] = inputs["input_ids"].copy()
    # Encode the source string as an integer id (source_to_id is built below, before
    # this function is applied) so the default collator can batch it into a tensor
    inputs["source"] = [source_to_id[s] for s in examples["source"]]
    return inputs

# Customized model
class CustomModel(Gemma2ForCausalLM):  # subclass the concrete class so from_pretrained() builds this class, not the base one
    def __init__(self, config):
        super().__init__(config)
        # We assume at most 1000 distinct sources; the embedding is added to the
        # logits, so its width has to match the vocabulary size.
        self.source_embedding = nn.Embedding(1000, config.vocab_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source=None, **kwargs):
        outputs = super().forward(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        if source is not None:
            # Bias the logits by the source embedding (out of place, so autograd is
            # not disturbed); the loss above is computed before this shift.
            source_embeds = self.source_embedding(source)
            outputs.logits = outputs.logits + source_embeds.unsqueeze(1)
        return outputs

# Customized Trainer
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions
        labels = inputs.pop("labels")
        source = inputs.pop("source", None)  # use None as the default value
        # Forward the source ids so the model's source_embedding is actually used
        outputs = model(**inputs, labels=labels, source=source)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

# Prepare the model and tokenizer
model_name = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = CustomModel.from_pretrained(model_name)

# Prepare the dataset
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path)
dataset = Dataset.from_list(data)
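# The default data collator cannot batch raw strings, so each distinct "source"
# string is mapped to an integer id (the helper name source_to_id is an addition
# to the original script); ids are capped at 1000 to match CustomModel.source_embedding.
source_to_id = {s: i % 1000 for i, s in enumerate(sorted(set(dataset["source"])))}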
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

# Initialize the Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./gemma2_finetuned")
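# Also save the tokenizer alongside the fine-tuned weights so the checkpoint
# can be reloaded on its own (this line is an addition to the original script)
tokenizer.save_pretrained("./gemma2_finetuned")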

# Function for generating an answer with a citation
def generate_answer(question, model, tokenizer, dataset):
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=200, num_return_sequences=1, output_scores=True, return_dict_in_generate=True)
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Find the most likely source (a rough heuristic: argmax over the tail of the
    # final generation scores, mapped back onto the dataset by index)
    source_probs = outputs.scores[-1][:, model.source_embedding.weight.shape[0]:]
    most_likely_source_idx = torch.argmax(source_probs).item()
    most_likely_source = dataset[most_likely_source_idx % len(dataset)]['source']
    return f"{answer}\n\nŹródło: {most_likely_source}"  # "Źródło" = "Source"

# Usage example
question = "Ile dni urlopu przysługuje pracownikowi?"  # "How many days of leave is an employee entitled to?"
answer = generate_answer(question, model, tokenizer, dataset)
print(answer)
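
# To reload the fine-tuned checkpoint later (paths assumed from the calls above):
# model = CustomModel.from_pretrained("./gemma2_finetuned")
# tokenizer = AutoTokenizer.from_pretrained("./gemma2_finetuned")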