Focal loss optimisation #1236

Open
vedantdalimkar wants to merge 4 commits into qubvel-org:main from vedantdalimkar:focal_loss_optimisation

Contributor

@vedantdalimkar vedantdalimkar commented Sep 14, 2025

This PR addresses #1235

The current focal loss implementation iterates over each class and computes the focal loss class by class. This is slightly inefficient and can be optimised by vectorising the loss computation in multiclass mode. The current implementation also uses expensive masking operations to filter out pixels belonging to the ignore_index class.
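
As a rough illustration of the difference, the sketch below contrasts a per-class loop with a vectorised version. It is not the code from this PR or from the library: `binary_focal_with_logits`, `focal_loop`, and `focal_vectorised` are hypothetical names, and the reduction (mean over non-ignored pixels, summed over classes) is an assumption made for the example.

```python
import torch
import torch.nn.functional as F


def binary_focal_with_logits(logits, targets, gamma=2.0):
    """Element-wise binary focal loss with no reduction (illustrative helper)."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    pt = torch.exp(-bce)  # probability the model assigns to the true label
    return (1.0 - pt) ** gamma * bce


def focal_loop(logits, labels, ignore_index, gamma=2.0):
    """Per-class loop with boolean masking, the pattern described as slow."""
    num_classes = logits.shape[1]
    keep = labels != ignore_index
    total = logits.new_zeros(())
    for cls in range(num_classes):
        cls_target = (labels == cls).float()[keep]
        cls_logits = logits[:, cls][keep]
        total = total + binary_focal_with_logits(cls_logits, cls_target, gamma).mean()
    return total


def focal_vectorised(logits, labels, ignore_index, gamma=2.0):
    """Same quantity computed for all classes at once."""
    num_classes = logits.shape[1]
    keep = labels != ignore_index
    # Send ignored pixels to an extra class, then drop that channel.
    safe = torch.where(keep, labels, torch.full_like(labels, num_classes))
    one_hot = F.one_hot(safe, num_classes + 1)[..., :-1]  # (N, ..., C)
    one_hot = one_hot.movedim(-1, 1).float()              # (N, C, ...)
    per_elem = binary_focal_with_logits(logits, one_hot, gamma)
    weight = keep.unsqueeze(1).float()                    # (N, 1, ...)
    # Zero out ignored pixels and divide by the number of kept pixels.
    return (per_elem * weight).sum() / weight.sum().clamp(min=1.0)
```

Both functions return the same value whenever at least one pixel is kept. The vectorised version replaces the Python loop and the boolean indexing with one broadcasted multiplication.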

I have also attached a notebook that benchmarks the new approach against the old one. The time improvement is significant, often speeding up the code by more than 2x! The notebook also shows that the output of the new function is consistent with the old one.
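
For readers without the notebook, a comparison of this kind can be sketched as below. This is a minimal harness, not the notebook's code: `compare` and the two toy functions are hypothetical stand-ins for the old and new loss implementations.

```python
import time
import torch


def compare(old_fn, new_fn, *args, repeats=20, atol=1e-6):
    """Return (outputs_match, old_seconds, new_seconds) for two callables."""
    matches = torch.allclose(old_fn(*args), new_fn(*args), atol=atol)

    def timed(fn):
        # Average wall-clock time per call over `repeats` runs.
        start = time.perf_counter()
        for _ in range(repeats):
            fn(*args)
        return (time.perf_counter() - start) / repeats

    return matches, timed(old_fn), timed(new_fn)


# Toy stand-ins (hypothetical): a per-channel loop and its vectorised form.
def loop_sum_of_squares(x):
    return sum((x[:, c] ** 2).mean() for c in range(x.shape[1]))


def vectorised_sum_of_squares(x):
    return (x ** 2).mean(dim=(0, 2, 3)).sum()


x = torch.randn(4, 8, 32, 32)
matches, old_s, new_s = compare(loop_sum_of_squares, vectorised_sum_of_squares, x)
print(f"match={matches} old={old_s:.6f}s new={new_s:.6f}s")
```

On a GPU, the timing loop would also need `torch.cuda.synchronize()` before each clock read, since CUDA kernels run asynchronously.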

@qubvel let me know if I need to add any more tests.

@vedantdalimkar
Contributor Author

@qubvel Gentle reminder: it would be great if you could take a look at this!

Collaborator

@qubvel qubvel left a comment

Sorry for the delay. Can you please make sure the formatting check passes so that the tests can run? Also, it would be nice to add some test cases to make sure it works as expected. Thanks!

Comment on lines +77 to +79
y_true[y_true == self.ignore_index] = num_classes
y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)
y_true_one_hot = y_true_one_hot[ ... , : -1]
Collaborator

Suggested change
- y_true[y_true == self.ignore_index] = num_classes
- y_true_one_hot = torch.nn.functional.one_hot(y_true,num_classes = num_classes + 1)
- y_true_one_hot = y_true_one_hot[ ... , : -1]
+ y_true[y_true == self.ignore_index] = num_classes
+ y_true_one_hot = torch.nn.functional.one_hot(y_true, num_classes = num_classes + 1)
+ y_true_one_hot = y_true_one_hot[ ... , :-1]
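
For context, the three lines under review remap ignored labels to an extra class index, one-hot encode with one extra channel, and then slice that channel off, so ignored pixels end up with an all-zero target. A small standalone sketch of the idea follows. The value `ignore_index = 255` and the tensor contents are assumptions for illustration, and `torch.where` is used here in place of the in-place assignment so the input labels are not modified, which matters if `y_true` has not been copied beforehand.

```python
import torch
import torch.nn.functional as F

num_classes = 3
ignore_index = 255  # assumed value, for illustration only

y_true = torch.tensor([[0, 2, 255],
                       [1, 255, 0]])

# Out-of-place remap: ignored labels go to an extra class index.
remapped = torch.where(
    y_true == ignore_index,
    torch.full_like(y_true, num_classes),
    y_true,
)

# Encode with one extra channel, then drop it.
one_hot = F.one_hot(remapped, num_classes=num_classes + 1)[..., :-1]

print(one_hot.shape)  # torch.Size([2, 3, 3])
print(one_hot[0, 2])  # ignored pixel: tensor([0, 0, 0])
```

Because the ignored rows are all zeros, no class receives a positive target for those pixels. Whether they still contribute to the loss as negatives depends on how the rest of the loss handles them.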


2 participants