62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
import mailbox
|
|
import email
|
|
from email.policy import default
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
|
|
def rolling_window(iterable, n):
    """Pass a rolling window over the iterable, yielding each n-length tuple.

    rolling_window(range(5), 3) -> (0, 1, 2), (1, 2, 3), (2, 3, 4)

    Args:
        iterable: Any iterable; it is consumed lazily, one item per window.
        n: Window length; must be at least 1.

    Yields:
        Tuples of ``n`` consecutive items. Nothing is yielded when the
        iterable holds fewer than ``n`` items.

    Raises:
        ValueError: If ``n`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # Imported locally so the function does not rely on module-level imports.
    from collections import deque
    from itertools import islice

    if n < 1:
        # A non-positive size used to produce a meaningless sequence
        # ((), (0,), (1,), ...) instead of failing loudly.
        raise ValueError("n must be a positive integer")

    it = iter(iterable)
    # maxlen=n makes every append drop the oldest item in O(1), instead of
    # copying the whole window on each step.
    window = deque(islice(it, n), maxlen=n)
    if len(window) < n:
        return  # not enough items for even a single full window

    yield tuple(window)
    for item in it:
        window.append(item)
        yield tuple(window)
|
|
|
|
|
|
def read_mail_text(mbox_path):
    """
    Extract and concatenate all plaintext content from an mbox file.

    Every ``text/plain`` part of every message is decoded, stripped of
    surrounding whitespace and collected. For a non-multipart message the
    message itself is the only part considered. Parts whose content cannot
    be decoded are skipped (and reported at DEBUG level) so that a single
    malformed mail does not abort the whole extraction.

    Args:
        mbox_path: Path of the mbox file to read.

    Returns:
        One string containing the collected texts, separated from each
        other by a blank line. Empty when no plaintext part was found.
    """
    import logging

    log = logging.getLogger(__name__)
    texts = []
    # NOTE: mailbox.mbox defaults to create=True, so a non-existent path
    # yields an empty mailbox (and creates the file) rather than an error.
    mbox = mailbox.mbox(
        mbox_path,
        factory=lambda f: email.message_from_binary_file(f, policy=default),
    )
    try:
        for msg in tqdm(mbox):
            # walk() yields the message itself first and then every nested
            # sub-part, so it covers multipart and single-part mails alike.
            for part in msg.walk():
                if part.get_content_type() != "text/plain":
                    continue
                try:
                    text = part.get_content()
                except Exception:
                    # Best effort: broken encodings / unknown charsets are
                    # skipped, but no longer silently.
                    log.debug(
                        "Skipping undecodable text/plain part", exc_info=True
                    )
                    continue
                if text:
                    texts.append(text.strip())
    finally:
        # Release the underlying file handle even if iteration fails.
        mbox.close()
    return "\n\n".join(texts)
|
|
|
|
|
|
def join(tokens, punctuation=".,?!:;'\""):
    """Glue tokens into one string, with no space in front of punctuation.

    A token is preceded by a single space unless ``token in punctuation``
    holds; leading and trailing whitespace of the result is stripped.
    """
    pieces = []
    for token in tokens:
        if token not in punctuation:
            pieces.append(' ')
        pieces.append(token)
    return ''.join(pieces).strip()
|
|
|
|
|
|
def softmax(x, axis=None):
    """Convert scores (logits) into a probability distribution.

    Args:
        x: Array-like of scores. Plain Python sequences are accepted as
            well as NumPy arrays.
        axis: Axis along which to normalise. The default ``None``
            normalises over all elements, which for a 1-D vector is the
            ordinary softmax.

    Returns:
        A NumPy array with the same shape as ``x`` whose entries sum to 1
        along ``axis`` (or over the whole array when ``axis`` is ``None``).

    Raises:
        ValueError: If ``x`` is empty (there is no maximum to subtract).
    """
    x = np.asarray(x)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing on large scores.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)
|