# ably.do/hft.py
# 286 lines, 11 KiB, Python

import json
import os
import re
from collections import defaultdict

import docx2txt
import PyPDF2
import pytesseract
import torch
import torch.nn as nn
from datasets import Dataset, Features, Value
from huggingface_hub import login
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# SECURITY: an access token used to be hard-coded here. A token committed to
# source control must be treated as leaked: revoke it on huggingface.co and
# provide a fresh one via the HF_TOKEN environment variable instead.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
class LegalAITrainer:
    """Train a legal-domain LM adapter and answer questions with verified citations.

    Training keeps a Hugging Face causal LM frozen and learns two small heads on
    top of it: a per-source embedding added to the token embeddings, and a
    confidence head. ``train()`` writes the base model, the tokenizer, the trained
    heads (``HEADS_FILE``) and the id -> source map (``SOURCE_MAP_FILE``) to
    ``DEFAULT_OUTPUT_DIR``. ``generate_response()`` reads them back.

    NOTE(review): the base model is frozen and generation runs without a source
    id, so the trained heads do not change the generated text. At inference only
    the source map is used, for citation checking. Confirm this is intended.
    """

    DEFAULT_OUTPUT_DIR = "./trained_legal_ai"
    HEADS_FILE = "legal_heads.pt"
    SOURCE_MAP_FILE = "source_mapper.json"
    # Catalog value (and default) that marks a document as general, non-legal text.
    GENERAL_DOC_TYPE = "Opracowanie własne"
    # Article markers such as "Art. 5" or "art.178a". The single capturing group
    # makes re.split() keep the markers between the article bodies.
    _ARTICLE_RE = re.compile(r"(?i)(Art\.\s*\d+[a-z]*)")

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # (model, tokenizer, known article keys) per model directory, so that
        # repeated generate_response() calls do not reload from disk.
        self._inference_cache = {}

    class SourceMapper:
        """Bidirectional mapping between source labels and dense integer ids."""

        def __init__(self):
            self.source_to_idx = {}
            self.idx_to_source = {}

        def add_source(self, source):
            """Register *source* (ignored when empty) under the next free id."""
            if source and source not in self.source_to_idx:
                idx = len(self.source_to_idx)
                self.source_to_idx[source] = idx
                self.idx_to_source[idx] = source

        def get_idx(self, source):
            """Return the id of *source*, or -1 if it is empty or was never added.

            Lookup has no side effects: an unknown source is NOT assigned an id.
            """
            if not source:
                return -1
            return self.source_to_idx.get(source, -1)

        def get_source(self, idx):
            """Return the label registered under *idx*, or "Unknown"."""
            return self.idx_to_source.get(idx, "Unknown")

    class LegalModel(nn.Module):
        """Frozen causal LM plus a trainable source embedding and confidence head."""

        def __init__(self, model_name, config, num_sources=100000):
            """Load the base LM and create the heads.

            Args:
                model_name: Hugging Face model id or local directory.
                config: Config passed to ``from_pretrained``; ``hidden_size`` sizes the heads.
                num_sources: Rows in the source embedding. The last row is the
                    zero "no source" padding row.
            """
            super().__init__()
            self.base_model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
            # nn.Embedding normalises padding_idx=-1 to num_sources - 1. That row
            # starts as zeros and gets no gradient, so it encodes "no source".
            self.source_embedding = nn.Embedding(num_sources, config.hidden_size, padding_idx=-1)
            self.confidence_layer = nn.Linear(config.hidden_size, 1)
            # Only the two heads are trainable; the base LM stays frozen.
            for param in self.base_model.parameters():
                param.requires_grad = False
            for layer in (self.source_embedding, self.confidence_layer):
                for param in layer.parameters():
                    param.requires_grad = True

        def _source_embeds(self, source_idx):
            """Return per-example source embeddings with shape (batch, 1, hidden).

            Negative ids (general text) select the zero padding row instead of
            being clamped onto source 0. Ids past the table size are clamped.
            """
            emb = self.source_embedding
            source_idx = source_idx.long()
            clamped = source_idx.clamp(0, emb.num_embeddings - 1)
            padding = torch.full_like(source_idx, emb.padding_idx)
            idx = torch.where(source_idx >= 0, clamped, padding)
            return emb(idx).unsqueeze(1)

        @staticmethod
        def _masked_mean(hidden, attention_mask=None):
            """Average *hidden* over dim 1, skipping positions where the mask is 0."""
            if attention_mask is None:
                return hidden.mean(dim=1)
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None):
            """Run the base LM (optionally source-conditioned) and the confidence head.

            Returns a dict with:
                ``loss``: LM loss from the base model (None without labels).
                ``logits``: LM logits.
                ``confidence``: sigmoid probability, shape (batch, 1).
                ``confidence_logits``: pre-sigmoid values, used for a stable loss.
                ``hidden_states``: all hidden layers of the base model.
            """
            if source_idx is not None:
                inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
                # The (batch, 1, hidden) source vector broadcasts over every token.
                inputs_embeds = inputs_embeds + self._source_embeds(source_idx).to(inputs_embeds.dtype)
                outputs = self.base_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True,
                )
            else:
                outputs = self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_hidden_states=True,
                )
            pooled = self._masked_mean(outputs.hidden_states[-1], attention_mask)
            confidence_logits = self.confidence_layer(pooled)
            return {
                "loss": outputs.loss,
                "logits": outputs.logits,
                "confidence": torch.sigmoid(confidence_logits),
                "confidence_logits": confidence_logits,
                "hidden_states": outputs.hidden_states,
            }

    def load_file_catalog(self, catalog_path):
        """Load the JSON catalog whose keys are lower-cased file stems.

        prepare_data() uses the catalog values as document types. If the file
        cannot be read or is not valid JSON, the error is printed and an empty
        dict is returned, so every document is then treated as general text.
        """
        try:
            with open(catalog_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
            print(f"Błąd ładowania katalogu: {str(e)}")
            return {}

    def extract_text(self, file_path):
        """Extract plain text from a supported document; return "" otherwise.

        Handles .txt/.md (UTF-8), .pdf (PyPDF2), .doc/.docx (docx2txt) and
        .jpg/.jpeg/.png (Tesseract OCR). Extraction errors are printed and yield "".
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext in (".txt", ".md"):
                with open(file_path, "r", encoding="utf-8") as f:
                    return f.read()
            if ext == ".pdf":
                with open(file_path, "rb") as f:
                    reader = PyPDF2.PdfReader(f)
                    return "".join(page.extract_text() or "" for page in reader.pages)
            if ext in (".doc", ".docx"):
                # NOTE(review): docx2txt targets the .docx format. Legacy binary
                # .doc files will likely fail here and yield "". Confirm.
                return docx2txt.process(file_path)
            if ext in (".jpg", ".jpeg", ".png"):
                with Image.open(file_path) as img:  # close the image file handle
                    return pytesseract.image_to_string(img)
            return ""
        except Exception as e:
            # Best effort: one unreadable file must not abort the corpus scan.
            print(f"Błąd przetwarzania {file_path}: {str(e)}")
            return ""

    def prepare_data(self, data_dir, catalog_path, chunk_size=512, min_article_chars=100):
        """Walk *data_dir* and build the training Dataset plus its SourceMapper.

        Files whose catalog type differs from GENERAL_DOC_TYPE are split into
        one "[LEGAL]" example per article, each tagged with a source id. All
        other files become "[GENERAL]" chunks of *chunk_size* characters with
        source id -1. Files and directories are visited in sorted order, so
        source ids are reproducible across runs.

        Returns:
            (datasets.Dataset with text/source_idx/is_legal columns, SourceMapper)
        """
        catalog = self.load_file_catalog(catalog_path)
        source_mapper = self.SourceMapper()
        data = []
        for root, dirs, files in os.walk(data_dir):
            dirs.sort()
            for file in sorted(files):
                text = self.extract_text(os.path.join(root, file))
                if not text:
                    continue
                doc_type = catalog.get(os.path.splitext(file)[0].lower(), self.GENERAL_DOC_TYPE)
                if doc_type != self.GENERAL_DOC_TYPE:
                    # NOTE(review): a legal file with no "Art." marker contributes nothing.
                    data.extend(self._legal_examples(text, doc_type, source_mapper, min_article_chars))
                else:
                    data.extend(self._general_examples(text, chunk_size))
        features = Features({
            "text": Value("string"),
            "source_idx": Value("int32"),
            "is_legal": Value("int32"),
        })
        columns = {name: [row[name] for row in data] for name in features}
        return Dataset.from_dict(columns, features=features), source_mapper

    @classmethod
    def _legal_examples(cls, text, doc_type, source_mapper, min_chars):
        """Split a legal act into per-article examples, registering each source."""
        parts = cls._ARTICLE_RE.split(text)
        examples = []
        # With one capturing group, split() yields [preamble, art, body, art, body, ...].
        for art_num, content in zip(parts[1::2], parts[2::2]):
            art_num, content = art_num.strip(), content.strip()
            if len(content) < min_chars:
                continue
            source = f"{doc_type}, {art_num}"
            source_mapper.add_source(source)
            examples.append({
                "text": f"[LEGAL] {art_num} {content}",
                "source_idx": source_mapper.get_idx(source),
                "is_legal": 1,
            })
        return examples

    @staticmethod
    def _general_examples(text, chunk_size):
        """Cut general text into fixed-size character chunks without a source."""
        return [
            {"text": f"[GENERAL] {text[start:start + chunk_size]}", "source_idx": -1, "is_legal": 0}
            for start in range(0, len(text), chunk_size)
        ]

    def train(self, model_name="crumb/nano-mistral", data_dir="data", catalog_path="catalog.json",
              output_dir=None, max_length=512):
        """Train the heads on *data_dir* and save all artifacts to *output_dir*.

        Raises:
            ValueError: if no training example could be extracted.
        """
        output_dir = output_dir or self.DEFAULT_OUTPUT_DIR
        dataset, source_mapper = self.prepare_data(data_dir, catalog_path)
        if len(dataset) == 0:
            raise ValueError(f"No training examples found in {data_dir!r}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_fn(examples):
            tokenized = tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=max_length,
            )
            # Labels are not stored: DataCollatorForLanguageModeling(mlm=False)
            # rebuilds them from input_ids and sets pad positions to -100.
            return {
                "input_ids": tokenized["input_ids"],
                "attention_mask": tokenized["attention_mask"],
                "source_idx": examples["source_idx"],
            }

        # Drop "text" and "is_legal". With remove_unused_columns=False they would
        # otherwise reach the collator and forward(), which has no such parameters.
        tokenized_dataset = dataset.map(
            tokenize_fn, batched=True, batch_size=16, remove_columns=dataset.column_names
        )
        config = AutoConfig.from_pretrained(model_name)  # config only, no weights
        model = self.LegalModel(model_name, config).to(self.device)
        training_args = TrainingArguments(
            output_dir="./legal_ai_model",
            num_train_epochs=3,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-5,
            fp16=torch.cuda.is_available(),
            logging_steps=50,
            save_strategy="steps",
            save_steps=500,
            report_to="none",
            remove_unused_columns=False,
        )

        class LegalTrainer(Trainer):
            """Trainer whose loss is LM loss + 0.7 * confidence BCE."""

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                # **kwargs absorbs arguments added by newer Trainer versions
                # (e.g. num_items_in_batch).
                outputs = model(**inputs)
                # NOTE(review): this target is derived from source_idx, which is
                # also fed to forward(), so the head may just learn to detect the
                # source embedding. Confirm this is the intended signal.
                target_conf = (inputs["source_idx"] >= 0).float()
                # Loss on logits: numerically stable, and permitted under fp16
                # autocast, unlike BCELoss on probabilities.
                conf_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    outputs["confidence_logits"].squeeze(-1).float(), target_conf
                )
                total_loss = outputs["loss"] + 0.7 * conf_loss
                return (total_loss, outputs) if return_outputs else total_loss

        trainer = LegalTrainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )
        print("Rozpoczęcie treningu...")
        trainer.train()
        self._save_artifacts(model, tokenizer, source_mapper, output_dir)
        self._inference_cache.pop(output_dir, None)  # force a reload of fresh artifacts
        print("Trening zakończony!")

    def _save_artifacts(self, model, tokenizer, source_mapper, output_dir):
        """Write base model, tokenizer, trained heads and source map to *output_dir*."""
        os.makedirs(output_dir, exist_ok=True)
        # LegalModel is a plain nn.Module without save_pretrained. Save the HF
        # base model in HF format and the two heads as a torch state dict.
        model.base_model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(
            {
                "source_embedding": model.source_embedding.state_dict(),
                "confidence_layer": model.confidence_layer.state_dict(),
            },
            os.path.join(output_dir, self.HEADS_FILE),
        )
        with open(os.path.join(output_dir, self.SOURCE_MAP_FILE), "w", encoding="utf-8") as f:
            json.dump(source_mapper.idx_to_source, f, ensure_ascii=False)

    def generate_response(self, prompt, confidence_threshold=0.65, model_dir=None, max_new_tokens=256):
        """Answer *prompt*, but only when the answer cites articles from the source map.

        Args:
            prompt: User question.
            confidence_threshold: Minimum heuristic confidence to accept an answer.
            model_dir: Directory written by train(); defaults to DEFAULT_OUTPUT_DIR.
            max_new_tokens: Tokens generated after the (truncated to 512) prompt.

        Returns:
            The generated answer followed by its verified citations, or a fixed
            refusal message when confidence is low or no citation is verified.
        """
        model, tokenizer, known_articles = self._load_inference_artifacts(
            model_dir or self.DEFAULT_OUTPUT_DIR
        )
        inputs = tokenizer(
            f"[PROMPT] {prompt} [RESPONSE]",
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                pad_token_id=tokenizer.eos_token_id,
                output_scores=True,
                return_dict_in_generate=True,
            )
        # Decode only the new tokens: the prompt scaffold is not returned, and
        # article numbers typed by the user cannot count as "verified".
        prompt_len = inputs.input_ids.shape[1]
        response = tokenizer.decode(outputs.sequences[0, prompt_len:], skip_special_tokens=True).strip()
        # NOTE(review): heuristic, not a calibrated probability. It is the sigmoid
        # of the EOS score at the last step; these are presumably the scores after
        # temperature/top_k processing (-inf when EOS was filtered out). Confirm.
        if outputs.scores:
            confidence = torch.sigmoid(outputs.scores[-1][0, tokenizer.eos_token_id]).item()
        else:
            confidence = 0.0
        verified = self._verified_citations(response, known_articles)
        if confidence < confidence_threshold or not verified:
            return "Nie mogę udzielić jednoznacznej odpowiedzi na podstawie dostępnych danych."
        return f"{response}\n\nPotwierdzone źródła: {', '.join(verified)}"

    def _load_inference_artifacts(self, model_dir):
        """Return (model, tokenizer, known article keys) for *model_dir*, cached."""
        cached = self._inference_cache.get(model_dir)
        if cached is None:
            model = AutoModelForCausalLM.from_pretrained(model_dir).to(self.device)
            model.eval()  # disable dropout for inference
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            with open(os.path.join(model_dir, self.SOURCE_MAP_FILE), "r", encoding="utf-8") as f:
                sources = json.load(f).values()
            # Sources are stored as "<doc_type>, <article>"; keep only the article.
            known_articles = {self._article_key(source.rsplit(", ", 1)[-1]) for source in sources}
            cached = (model, tokenizer, known_articles)
            self._inference_cache[model_dir] = cached
        return cached

    @staticmethod
    def _article_key(citation):
        """Normalise an article reference ("Art. 178a" -> "art.178a") for exact matching."""
        return re.sub(r"\s+", "", citation).lower()

    @classmethod
    def _verified_citations(cls, text, known_articles):
        """Return the article citations in *text* whose normalised key is known.

        Results are de-duplicated and kept in order of first appearance. The
        match is exact, so "Art. 17" is not confirmed by a known "Art. 178a".
        """
        verified = []
        seen = set()
        for citation in cls._ARTICLE_RE.findall(text):
            key = cls._article_key(citation)
            if key in known_articles and key not in seen:
                seen.add(key)
                verified.append(citation)
        return verified
if __name__ == "__main__":
    # Script entry: train on the local corpus, then run one smoke-test question.
    assistant = LegalAITrainer()
    training_options = {
        "model_name": "crumb/nano-mistral",
        "data_dir": "./legal_docs",
        "catalog_path": "./catalog.json",
    }
    assistant.train(**training_options)
    question = "Jakie są kary za prowadzenie pojazdu pod wpływem alkoholu?"
    answer = assistant.generate_response(question)
    print(answer)