l.gabrysiak 2025-02-25 22:21:41 +01:00
parent c55fbe8632
commit 2f6da20984
1 changed file with 134 additions and 105 deletions

hft.py

@@ -19,12 +19,12 @@ from transformers import (
DataCollatorForLanguageModeling
)
from datasets import Dataset
from nlpaug import Augmenter, CharAugmenter, WordAugmenter
from nlpaug.augmenter.word import WordAugmenter
from huggingface_hub import login
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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") # Zastąp prawdziwym tokenem
login(token="TWÓJ_TOKEN_HF") # Zastąp swoim tokenem
class SourceMapper:
def __init__(self):
@@ -42,10 +42,10 @@ class SourceMapper:
def get_source(self, idx):
return self.idx_to_source.get(idx, "Unknown")
class LegalDataProcessor:
class LegalProcessor:
def __init__(self, catalog_path):
self.catalog = self.load_catalog(catalog_path)
self.augmenter = WordAugmenter.AntonymAug()
self.augmenter = self.init_augmenter()
def load_catalog(self, path):
try:
@@ -54,19 +54,26 @@ class LegalDataProcessor:
except:
return defaultdict(str)
def identify_document(self, filename):
base = os.path.splitext(filename)[0].lower()
return self.catalog.get(base, "Opracowanie własne")
def init_augmenter(self):
return WordAugmenter.SynonymAug(aug_src='wordnet', aug_max=3)
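Note that WordAugmenter.SynonymAug(...) will raise an AttributeError: SynonymAug is a class in the nlpaug.augmenter.word module, not an attribute of the WordAugmenter base class. A minimal sketch of the usual construction (the example sentence is illustrative; the WordNet source is English-only, so Polish input will mostly pass through unchanged):

import nlpaug.augmenter.word as naw

# SynonymAug draws replacements from WordNet (needs the NLTK wordnet corpora downloaded once);
# aug_max caps the number of substitutions per call
aug = naw.SynonymAug(aug_src='wordnet', aug_max=3)
print(aug.augment("The court shall hear the case within 30 days"))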
def process_file(self, file_path):
text = self.extract_text(file_path)
if not text:
return []
doc_type = self.identify_doc_type(file_path)
return self.split_content(text, doc_type)
def extract_text(self, file_path):
ext = os.path.splitext(file_path)[1].lower()
try:
if ext == '.pdf':
return self._extract_pdf(file_path)
return self.extract_pdf(file_path)
elif ext in ['.doc', '.docx']:
return docx2txt.process(file_path)
elif ext in ['.jpg', '.jpeg', '.png']:
return self._extract_ocr(file_path)
return self.extract_image(file_path)
else:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
@@ -74,7 +81,7 @@ class LegalDataProcessor:
print(f"Błąd przetwarzania {file_path}: {str(e)}")
return ""
def _extract_pdf(self, path):
def extract_pdf(self, path):
text = ""
with open(path, 'rb') as f:
reader = PyPDF2.PdfReader(f)
@@ -82,116 +89,139 @@ class LegalDataProcessor:
text += page.extract_text() + "\n"
return re.sub(r'\s+', ' ', text)
def _extract_ocr(self, path):
def extract_image(self, path):
return pytesseract.image_to_string(
Image.open(path),
config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
)
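For Polish scans, passing the language to Tesseract usually matters more than the segmentation flags. A hedged variant, assuming the pol traineddata is installed alongside the Tesseract binary (the helper name is illustrative):

import pytesseract
from PIL import Image

def extract_image_pl(path):
    # lang='pol' requires the tesseract-ocr-pol language pack on the host
    return pytesseract.image_to_string(
        Image.open(path),
        lang='pol',
        config='--psm 4 --oem 3 -c preserve_interword_spaces=1'
    )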
def process_legal(self, text, doc_type):
articles = re.split(
r'(?ix)(Art\.?\s*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)\b',
text
)
processed = []
def identify_doc_type(self, file_path):
base = os.path.splitext(os.path.basename(file_path))[0].lower()
return self.catalog.get(base, "Custom")
def split_content(self, text, doc_type):
if doc_type == "Custom":
return self.split_custom(text)
return self.split_legal(text, doc_type)
def split_legal(self, text, doc_type):
pattern = r'(?i)(Art[\.\s]*\d+[a-z]*|§\s*\d+|Rozdział\s+[IVXLCDM]+)'
parts = re.split(pattern, text)
results = []
current_header = ""
for item in articles:
if item and re.match(r'(?i)(Art|§|Rozdział)', item):
for part in parts:
if not part:
continue
if re.match(pattern, part):
if current_header:
processed.append(current_header)
current_header = item.strip()
elif current_header:
processed.append(current_header + " " + item.strip())
current_header = ""
results.append(current_header)
current_header = f"[{doc_type}] {part.strip()}"
else:
processed.append(item.strip())
if current_header:
results.append(f"{current_header}: {part.strip()}")
current_header = ""
else:
results.append(part.strip())
return [
(f"[{doc_type}] {p}", doc_type)
for p in processed if len(p) > 30
]
return [text for text in results if len(text) > 50]
def process_custom(self, text):
def split_custom(self, text):
clean_text = re.sub(r'\s+', ' ', text).strip()
chunk_size = 384
overlap = 128
overlap = 64
chunks = [
clean_text[i:i+chunk_size]
for i in range(0, len(clean_text), chunk_size - overlap)
]
return [("[Custom] " + c, "Custom") for c in chunks if c.strip()]
chunks = []
start = 0
while start < len(clean_text):
end = start + chunk_size
chunks.append(clean_text[start:end])
start = end - overlap
return [f"[Custom] {chunk}" for chunk in chunks if chunk.strip()]
class EnhancedDataCollator(DataCollatorForLanguageModeling):
def torch_call(self, examples):
batch = super().torch_call(examples)
if "source_idx" in examples[0]:
batch["source_idx"] = torch.tensor(
[ex["source_idx"] for ex in examples],
dtype=torch.long
)
return batch
class CustomModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
self.base_model = AutoModelForCausalLM.from_pretrained(model_name)
self.source_emb = torch.nn.Embedding(1000, self.base_model.config.hidden_size)
# Freeze the base model parameters
for param in self.base_model.parameters():
param.requires_grad = False
# Unfreeze the last layers
for layer in self.base_model.transformer.h[-2:]:
for param in layer.parameters():
param.requires_grad = True
self.base_model.get_output_embeddings().requires_grad_(True)
def forward(self, input_ids, attention_mask, labels, source_idx):
inputs_embeds = self.base_model.get_input_embeddings()(input_ids)
source_emb = self.source_emb(source_idx.clamp(0, 999)).unsqueeze(1)
inputs_embeds += source_emb
return self.base_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
labels=labels
)
def main():
# Configuration
# Initialize components
source_mapper = SourceMapper()
processor = LegalDataProcessor("file_catalog.json")
processor = LegalProcessor("file_catalog.json")
tokenizer = AutoTokenizer.from_pretrained("crumb/nano-mistral")
tokenizer.pad_token = tokenizer.eos_token
# Prepare the data
# Process the data
data = []
def process_file(file_path):
nonlocal data
text = processor.extract_text(file_path)
if not text:
return
doc_type = processor.identify_document(os.path.basename(file_path))
if doc_type != "Opracowanie własne":
processed = processor.process_legal(text, doc_type)
else:
processed = processor.process_custom(text)
for text, source in processed:
source_mapper.add_source(source)
data.append({
"text": text,
"source_idx": source_mapper.get_idx(source)
})
def process_and_augment(file_path):
try:
items = processor.process_file(file_path)
for text in items:
source = text.split("]")[0][1:]
source_mapper.add_source(source)
# Original text
data.append({
"text": text,
"source_idx": source_mapper.get_idx(source)
})
# Augmentation - 2 variants
for _ in range(2):
words = text.split()
if len(words) > 5:
# Randomly shuffle the word order
random.shuffle(words)
augmented = " ".join(words)
data.append({
"text": augmented,
"source_idx": source_mapper.get_idx(source)
})
except Exception as e:
print(f"Błąd przetwarzania {file_path}: {str(e)}")
# Przetwarzanie wielowątkowe
with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
futures = []
for root, _, files in os.walk("files"):
for root, _, files in os.walk("files"): # Data folder
for file in files:
futures.append(executor.submit(
process_file,
os.path.join(root, file)
))
file_path = os.path.join(root, file)
futures.append(executor.submit(process_and_augment, file_path))
for future in futures:
try:
future.result()
except Exception as e:
print(f"Błąd: {str(e)}")
# Augmentacja
print(f"Przed augmentacją: {len(data)} przykładów")
augmented = []
for item in data:
for _ in range(2): # 2 extra variants
sentences = item['text'].split('. ')
random.shuffle(sentences)
augmented.append({
"text": '. '.join(sentences),
"source_idx": item["source_idx"]
})
data += augmented
print(f"Po augmentacji: {len(data)} przykładów")
future.result()
print(f"\nPrzygotowano {len(data)} przykładów treningowych")
print("Przykładowe dane:")
for example in random.sample(data, 3):
print(f"\nŹródło: {source_mapper.get_source(example['source_idx'])}")
print(f"Tekst: {example['text'][:150]}...")
# Przygotowanie datasetu
dataset = Dataset.from_list(data)
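The tokenization step sits in the gap between this hunk and the next; only the batch_size=32 and num_proc=4 arguments of the map call are visible below. A minimal sketch of what such a map typically looks like for causal LM fine-tuning, with the function body and parameter values assumed rather than taken from the commit:

def tokenize_fn(batch):
    # Pad/truncate to a fixed length so examples can be batched for causal LM training
    return tokenizer(
        batch["text"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )

tokenized_ds = dataset.map(
    tokenize_fn,
    batched=True,
    batch_size=32,
    num_proc=4,
    remove_columns=["text"],
)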
@@ -216,21 +246,19 @@ def main():
batch_size=32,
num_proc=4
)
# Initialize the model
model = CustomModel("crumb/nano-mistral")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Model
model = AutoModelForCausalLM.from_pretrained(
"crumb/nano-mistral",
trust_remote_code=True
)
model.resize_token_embeddings(len(tokenizer))
# Training
# Training configuration
training_args = TrainingArguments(
output_dir="./results",
output_dir="./wyniki",
num_train_epochs=5,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
learning_rate=1e-4,
learning_rate=2e-5,
fp16=torch.cuda.is_available(),
logging_steps=20,
save_strategy="epoch",
@@ -241,16 +269,17 @@ def main():
model=model,
args=training_args,
train_dataset=tokenized_ds,
data_collator=EnhancedDataCollator(tokenizer=tokenizer, mlm=False)
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
)
print("Rozpoczęcie treningu...")
# Trening
print("\nRozpoczynanie treningu...")
trainer.train()
print("Trening zakończony!")
# Zapisz model
model.save_pretrained("./trained_model")
tokenizer.save_pretrained("./trained_model")
# Save the model
model.save_pretrained("./trained_legal_model")
tokenizer.save_pretrained("./trained_legal_model")
print("Trening zakończony pomyślnie!")
if __name__ == "__main__":
main()
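A hedged usage sketch for the artefacts written to ./trained_legal_model; the prompt and generation parameters are illustrative assumptions, not part of the commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./trained_legal_model")
model = AutoModelForCausalLM.from_pretrained("./trained_legal_model", trust_remote_code=True)

prompt = "Art. 1. Kodeks cywilny reguluje"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))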