Revisions

This commit is contained in:
Chris Proctor
2026-06-08 15:15:52 -04:00
parent 1c99ca8bd3
commit 49c4e43f45
12 changed files with 209 additions and 34 deletions

View File

@@ -72,9 +72,13 @@ class CNNClassifier:
self._model = model
return self
def predict(self, X):
def predict_proba(self, X):
X_te = torch.tensor(X, dtype=torch.float32)
self._model.eval()
with torch.no_grad():
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
return predictions
logits = self._model(X_te.to(self._device))
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
return probabilities
def predict(self, X):
return self.predict_proba(X).argmax(axis=1)

View File

@@ -34,7 +34,7 @@ class FeatureExtractor:
}
class HandPickedClassifier:
class FeatureClassifier:
def fit(self, X, y):
self._pipeline = Pipeline([
("features", FeatureExtractor()),
@@ -46,3 +46,6 @@ class HandPickedClassifier:
def predict(self, X):
return self._pipeline.predict(X)
def predict_proba(self, X):
return self._pipeline.predict_proba(X)

View File

@@ -66,9 +66,13 @@ class MLPClassifier:
self._model = model
return self
def predict(self, X):
def predict_proba(self, X):
X_te = torch.tensor(X, dtype=torch.float32)
self._model.eval()
with torch.no_grad():
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
return predictions
logits = self._model(X_te.to(self._device))
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
return probabilities
def predict(self, X):
return self.predict_proba(X).argmax(axis=1)

View File

@@ -9,3 +9,6 @@ class PixelClassifier:
def predict(self, X):
return self._classifier.predict(X)
def predict_proba(self, X):
return self._classifier.predict_proba(X)