Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
* The result of a {@link RegistrationGuard} evaluation indicating whether
* a registration attempt should be allowed or denied.
*
* <p>Use the static factory methods {@link #allow()} and {@link #deny(String)}
* rather than the canonical constructor. The {@code reason} parameter is only
* meaningful when {@code allowed} is {@code false}.</p>
*
* @param allowed whether the registration is permitted
* @param reason a human-readable denial reason (may be {@code null} when allowed)
* @param reason a human-readable denial reason; only meaningful when {@code allowed}
* is {@code false}, may be {@code null} when allowed
*/
public record RegistrationDecision(boolean allowed, String reason) {

/** Default reason used when {@link #deny(String)} is called with a blank or null reason. */
private static final String DEFAULT_DENIAL_REASON = "Registration denied.";

/**
* Creates a decision that allows the registration to proceed.
*
Expand All @@ -25,8 +33,8 @@ public static RegistrationDecision allow() {
* @return a denying decision with the given reason
*/
public static RegistrationDecision deny(String reason) {
if (reason == null || reason.trim().isEmpty()) {
reason = "Registration denied.";
if (reason == null || reason.isBlank()) {
reason = DEFAULT_DENIAL_REASON;
}
return new RegistrationDecision(false, reason);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
* @see RegistrationDecision
* @see DefaultRegistrationGuard
*/
@FunctionalInterface
public interface RegistrationGuard {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public User handleOAuthLoginSuccess(String registrationId, OAuth2User oAuth2User
log.info("Registration denied for email: {} source: OAUTH2 provider: {} reason: {}",
user.getEmail(), registrationId, decision.reason());
throw new OAuth2AuthenticationException(
new OAuth2Error("registration_denied"), decision.reason());
new OAuth2Error("registration_denied", decision.reason(), null), decision.reason());
}
user = registerNewOAuthUser(registrationId, user);
return user;
Expand All @@ -126,18 +126,17 @@ public User handleOAuthLoginSuccess(String registrationId, OAuth2User oAuth2User
* @param user The User object representing the authenticated user.
* @return A User object representing the newly registered user.
*/
@Transactional
private User registerNewOAuthUser(String registrationId, User user) {
User.Provider provider = User.Provider.valueOf(registrationId.toUpperCase());
user.setProvider(provider);
user.setRoles(Arrays.asList(roleRepository.findByName(USER_ROLE_NAME)));
// We will trust OAuth2 providers to provide us with a verified email address.
user.setEnabled(true);
AuditEvent registrationAuditEvent = AuditEvent.builder().source(this).user(user).action("OAuth2 Registration Success").actionStatus("Success")
User savedUser = userRepository.save(user);
AuditEvent registrationAuditEvent = AuditEvent.builder().source(this).user(savedUser).action("OAuth2 Registration Success").actionStatus("Success")
.message("Registration Confirmed. User logged in.").build();

eventPublisher.publishEvent(registrationAuditEvent);
return userRepository.save(user);
return savedUser;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
package com.digitalsanctuary.spring.user.service;

import java.util.Arrays;
import com.digitalsanctuary.spring.user.persistence.model.User;
import com.digitalsanctuary.spring.user.persistence.repository.RoleRepository;
import com.digitalsanctuary.spring.user.persistence.repository.UserRepository;
import com.digitalsanctuary.spring.user.registration.RegistrationContext;
import com.digitalsanctuary.spring.user.registration.RegistrationDecision;
import com.digitalsanctuary.spring.user.registration.RegistrationGuard;
import com.digitalsanctuary.spring.user.registration.RegistrationSource;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.Locale;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
Expand All @@ -19,6 +12,17 @@
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.digitalsanctuary.spring.user.audit.AuditEvent;
import com.digitalsanctuary.spring.user.persistence.model.User;
import com.digitalsanctuary.spring.user.persistence.repository.RoleRepository;
import com.digitalsanctuary.spring.user.persistence.repository.UserRepository;
import com.digitalsanctuary.spring.user.registration.RegistrationContext;
import com.digitalsanctuary.spring.user.registration.RegistrationDecision;
import com.digitalsanctuary.spring.user.registration.RegistrationGuard;
import com.digitalsanctuary.spring.user.registration.RegistrationSource;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

/**
* OIDC user service implementation for handling OpenID Connect authentication (Keycloak).
Expand All @@ -36,17 +40,27 @@
*/
@Slf4j
@Service
@Transactional
@RequiredArgsConstructor
public class DSOidcUserService implements OAuth2UserService<OidcUserRequest, OidcUser> {
Comment on lines 41 to 45
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding @Transactional at the class level means the transaction starts before defaultOidcUserService.loadUser(userRequest) runs, so the DB transaction/connection can be held open while making a network call to the OIDC provider. To reduce contention and improve reliability, consider narrowing the transaction scope to only the DB interaction portion (e.g., using a separate transactional component/method invoked via proxy, or TransactionTemplate around handleOidcLoginSuccess + login helper persistence work).

Copilot uses AI. Check for mistakes.

/** The user repository. */
private final UserRepository userRepository;

/** The role repository. */
private final RoleRepository roleRepository;

/** The login helper service. */
private final LoginHelperService loginHelperService;

private final RegistrationGuard registrationGuard;

/** The Event Publisher. */
private final ApplicationEventPublisher eventPublisher;

/** The user role name. */
private static final String USER_ROLE_NAME = "ROLE_USER";

OidcUserService defaultOidcUserService = new OidcUserService();

/**
Expand Down Expand Up @@ -78,8 +92,11 @@ public User handleOidcLoginSuccess(String registrationId, OidcUser oidcUser) {
throw new OAuth2AuthenticationException(new OAuth2Error("Missing Email"),
"Unable to retrieve email address from " + registrationId + ". Please ensure you have granted email permissions.");
}
log.debug("handleOidcLoginSuccess: looking up user with email: {}", user.getEmail());
User existingUser = userRepository.findByEmail(user.getEmail());
// Normalize email for consistent lookup — getUserFromKeycloakOidc2User already lowercases,
// but we normalize again here defensively in case additional sources are added.
String normalizedEmail = user.getEmail().trim().toLowerCase(Locale.ROOT);
log.debug("handleOidcLoginSuccess: looking up user with email: {}", normalizedEmail);
User existingUser = userRepository.findByEmail(normalizedEmail);
log.debug("handleOidcLoginSuccess: existingUser: {}", existingUser);
if (existingUser != null && registrationId != null) {
log.debug("handleOidcLoginSuccess: existingUser.getProvider(): {}", existingUser.getProvider());
Expand All @@ -93,14 +110,14 @@ public User handleOidcLoginSuccess(String registrationId, OidcUser oidcUser) {
existingUser = updateExistingUser(existingUser, user);
return userRepository.save(existingUser);
} else {
log.debug("handleOidcLoginSuccess: registering new user with email: {}", user.getEmail());
log.debug("handleOidcLoginSuccess: registering new user with email: {}", normalizedEmail);
RegistrationDecision decision = registrationGuard.evaluate(
new RegistrationContext(user.getEmail(), RegistrationSource.OIDC, registrationId));
new RegistrationContext(normalizedEmail, RegistrationSource.OIDC, registrationId));
if (!decision.allowed()) {
log.info("Registration denied for email: {} source: OIDC provider: {} reason: {}",
user.getEmail(), registrationId, decision.reason());
normalizedEmail, registrationId, decision.reason());
throw new OAuth2AuthenticationException(
new OAuth2Error("registration_denied"), decision.reason());
new OAuth2Error("registration_denied", decision.reason(), null), decision.reason());
}
user = registerNewOidcUser(registrationId, user);
return user;
Expand All @@ -119,10 +136,14 @@ public User handleOidcLoginSuccess(String registrationId, OidcUser oidcUser) {
private User registerNewOidcUser(String registrationId, User user) {
User.Provider provider = User.Provider.valueOf(registrationId.toUpperCase());
user.setProvider(provider);
user.setRoles(Arrays.asList(roleRepository.findByName("ROLE_USER")));
// We will trust OAuth2 providers to provide us with a verified email address.
user.setRoles(Arrays.asList(roleRepository.findByName(USER_ROLE_NAME)));
// We will trust OIDC providers to provide us with a verified email address.
user.setEnabled(true);
return userRepository.save(user);
User savedUser = userRepository.save(user);
AuditEvent registrationAuditEvent = AuditEvent.builder().source(this).user(savedUser).action("OIDC Registration Success").actionStatus("Success")
.message("Registration Confirmed. User logged in.").build();
eventPublisher.publishEvent(registrationAuditEvent);
return savedUser;
}

/**
Expand Down Expand Up @@ -153,11 +174,8 @@ public User getUserFromKeycloakOidc2User(OidcUser principal) {
}
log.debug("Principal attributes: {}", principal.getAttributes());
User user = new User();
/* user.setEmail(principal.getAttribute("email"));
user.setFirstName(principal.getAttribute("given_name"));
user.setLastName(principal.getAttribute("family_name"));*/
String email = principal.getEmail();
user.setEmail(email != null ? email.toLowerCase() : null);
user.setEmail(email != null ? email.trim().toLowerCase(Locale.ROOT) : null);
user.setFirstName(principal.getGivenName());
user.setLastName(principal.getFamilyName());
user.setProvider(User.Provider.KEYCLOAK);
Expand All @@ -175,17 +193,11 @@ public User getUserFromKeycloakOidc2User(OidcUser principal) {
*/
@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
log.debug("Loading user from OAuth2 provider with userRequest: {}", userRequest);
log.debug("Loading user from OIDC provider with userRequest: {}", userRequest);
OidcUser user = defaultOidcUserService.loadUser(userRequest);
String registrationId = userRequest.getClientRegistration().getRegistrationId();
log.debug("registrationId: {}", registrationId);
User dbUser = handleOidcLoginSuccess(registrationId, user);
DSUserDetails dsUserDetails = DSUserDetails.builder()
.user(dbUser)
.oidcUserInfo(user.getUserInfo())
.oidcIdToken(user.getIdToken())
.grantedAuthorities(user.getAuthorities())
.build();
return dsUserDetails;
return loginHelperService.userLoginHelper(dbUser, user.getUserInfo(), user.getIdToken());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.util.Collection;
import java.util.Date;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.digitalsanctuary.spring.user.persistence.model.User;
Expand Down Expand Up @@ -49,4 +51,24 @@ public DSUserDetails userLoginHelper(User dbUser) {
DSUserDetails userDetails = new DSUserDetails(dbUser, authorities);
return userDetails;
}

/**
* Helper method to authenticate an OIDC user after login, attaching the OIDC-specific tokens
* and claims to the principal while keeping {@link DSUserDetails} immutable.
*
* @param dbUser The user to authenticate.
* @param oidcUserInfo The OIDC user info claims.
* @param oidcIdToken The OIDC ID token.
* @return The user details object with OIDC tokens set.
*/
public DSUserDetails userLoginHelper(User dbUser, OidcUserInfo oidcUserInfo, OidcIdToken oidcIdToken) {
// Updating lastActivity date for this login
dbUser.setLastActivityDate(new Date());

// Check if the user account is locked, but should be unlocked now, and unlock it
dbUser = loginAttemptService.checkIfUserShouldBeUnlocked(dbUser);

Collection<? extends GrantedAuthority> authorities = authorityService.getAuthoritiesFromUser(dbUser);
return new DSUserDetails(dbUser, oidcUserInfo, oidcIdToken, authorities);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.digitalsanctuary.spring.user.registration;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

@DisplayName("RegistrationContext Tests")
class RegistrationContextTest {

@Test
@DisplayName("Should reject null source in RegistrationContext")
void shouldRejectNullSource() {
assertThatThrownBy(() -> new RegistrationContext("user@example.com", null, null))
.isInstanceOf(NullPointerException.class)
.hasMessageContaining("source must not be null");
}

@Test
@DisplayName("Should accept valid context with all fields")
void shouldAcceptValidContext() {
RegistrationContext context = new RegistrationContext("user@example.com", RegistrationSource.OAUTH2, "google");

assertThat(context.email()).isEqualTo("user@example.com");
assertThat(context.source()).isEqualTo(RegistrationSource.OAUTH2);
assertThat(context.providerName()).isEqualTo("google");
}

@Test
@DisplayName("Should accept null email and null providerName")
void shouldAcceptNullEmailAndProvider() {
RegistrationContext context = new RegistrationContext(null, RegistrationSource.FORM, null);

assertThat(context.email()).isNull();
assertThat(context.source()).isEqualTo(RegistrationSource.FORM);
assertThat(context.providerName()).isNull();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;

Expand All @@ -37,9 +38,15 @@ class DSOidcUserServiceRegistrationGuardTest {
@Mock
private RoleRepository roleRepository;

@Mock
private LoginHelperService loginHelperService;

@Mock
private RegistrationGuard registrationGuard;

@Mock
private ApplicationEventPublisher eventPublisher;

@InjectMocks
private DSOidcUserService service;

Expand Down
Loading
Loading