Revisions

This commit is contained in:
Chris Proctor
2026-06-08 15:15:52 -04:00
parent 1c99ca8bd3
commit 49c4e43f45
12 changed files with 209 additions and 34 deletions

View File

@@ -66,9 +66,13 @@ class MLPClassifier:
self._model = model
return self
def predict(self, X):
def predict_proba(self, X):
X_te = torch.tensor(X, dtype=torch.float32)
self._model.eval()
with torch.no_grad():
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
return predictions
logits = self._model(X_te.to(self._device))
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
return probabilities
def predict(self, X):
return self.predict_proba(X).argmax(axis=1)