29 lines
673 B
Python
29 lines
673 B
Python
import os
|
|
|
|
import joblib
|
|
|
|
# Root directory under which each named model gets its own sub-directory.
WEIGHTS_DIR: str = "weights"
# File name of the joblib-serialized object inside a model directory.
MODEL_FILE: str = "model.joblib"
|
|
|
|
|
|
def _resolve(name):
    """Map a model name or path to its directory under ``WEIGHTS_DIR``.

    A *name* that already starts with ``WEIGHTS_DIR`` followed by a path
    separator (``os.sep`` or ``"/"``) is returned unchanged. Anything else,
    including a bare ``"weights"`` with no trailing separator, is joined
    onto ``WEIGHTS_DIR``. If *name* is absolute, ``os.path.join`` discards
    ``WEIGHTS_DIR`` and *name* is used as-is.

    Args:
        name: Model name or path, as a ``str`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``).

    Returns:
        The resolved directory path.
    """
    # os.fspath lets pathlib.Path (and other PathLike) inputs through;
    # calling .startswith on a Path object would raise AttributeError.
    name = os.fspath(name)
    # Accept both the native separator and "/" so "weights/x" is recognised
    # on Windows as well as POSIX.
    prefixes = (WEIGHTS_DIR + os.sep, WEIGHTS_DIR + "/")
    if name.startswith(prefixes):
        return name
    return os.path.join(WEIGHTS_DIR, name)
|
|
|
|
|
|
def is_saved_model(path):
    """Report whether *path* resolves to a directory holding a saved model.

    Args:
        path: Model name or path, resolved via ``_resolve``.

    Returns:
        ``True`` when ``<resolved directory>/MODEL_FILE`` exists and is a
        regular file, ``False`` otherwise.
    """
    # isfile() rather than exists(): a *directory* named MODEL_FILE is not a
    # loadable model. isfile() on the joined path also implies the parent
    # directory exists, so a separate isdir() check is unnecessary.
    return os.path.isfile(os.path.join(_resolve(path), MODEL_FILE))
|
|
|
|
|
|
def save_model(clf, name):
    """Serialize *clf* with joblib into ``<resolved directory>/MODEL_FILE``.

    The directory is created if needed. The object is first dumped to a
    uniquely named temporary file in the same directory and then moved into
    place with ``os.replace``. An interrupted or failing dump therefore never
    leaves a truncated ``MODEL_FILE`` that ``is_saved_model`` would report
    as a saved model.

    Args:
        clf: Object to persist (presumably a fitted estimator; anything
            joblib can serialize).
        name: Model name or path, resolved via ``_resolve``.
    """
    directory = _resolve(name)
    os.makedirs(directory, exist_ok=True)
    target = os.path.join(directory, MODEL_FILE)
    # Same directory => same filesystem, so os.replace is an atomic rename.
    # A random component keeps concurrent saves from sharing a temp file.
    # ".tmp" is not an extension joblib maps to a compressor, so the dump
    # uses joblib's default settings, as writing to "model.joblib" would.
    tmp_path = f"{target}.{os.urandom(8).hex()}.tmp"
    try:
        joblib.dump(clf, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def load_model(path):
    """Deserialize and return the object saved under *path*.

    *path* is resolved with ``_resolve`` and the object is read from the
    ``MODEL_FILE`` inside that directory using ``joblib.load``.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load models from locations you trust.
    model_file = os.path.join(_resolve(path), MODEL_FILE)
    return joblib.load(model_file)
|