Refactor API key reset as part of settings form

This adds core CSRF protection to the reset functionality.
This commit is contained in:
Christopher C. Wells 2021-06-21 21:27:45 -07:00
parent cca1e409e3
commit 1689bc8e20
5 changed files with 22 additions and 26 deletions

View File

@ -81,12 +81,12 @@
<label for="id_email" class="col-sm-2 col-form-label">{% trans "Key" %}</label> <label for="id_email" class="col-sm-2 col-form-label">{% trans "Key" %}</label>
<div class="col-sm-10"> <div class="col-sm-10">
<samp>{{ user.settings.api_key }}</samp> <samp>{{ user.settings.api_key }}</samp>
<a class="btn btn-xs btn-danger" href="{% url 'babybuddy:user-reset-api-key' %}">{% trans "Regenerate" %}</a> <input type="submit" name="api_key_regenerate" value="{% trans "Regenerate" %}" class="btn btn-danger btn-xs" />
</div> </div>
</div> </div>
</fieldset> </fieldset>
<input type="hidden" name="next" value="{% url 'babybuddy:user-settings' %}" /> <input type="hidden" name="next" value="{% url 'babybuddy:user-settings' %}" />
<button type="submit" class="btn btn-primary">{% trans "Submit" %}</button> <input type="submit" name="save_settings" value="{% trans "Submit" %}" class="btn btn-primary">
</form> </form>
</div> </div>
{% endblock %} {% endblock %}

View File

@ -104,6 +104,21 @@ class FormsTestCase(TestCase):
self.assertEqual(page.status_code, 200) self.assertEqual(page.status_code, 200)
self.assertContains(page, 'New First Name') self.assertContains(page, 'New First Name')
def test_user_regenerate_api_key(self):
self.c.login(**self.credentials)
api_key_before = User.objects.get(pk=self.user.id).settings.api_key()
params = self.settings_template.copy()
params['api_key_regenerate'] = 'Regenerate'
page = self.c.post('/user/settings/', params, follow=True)
self.assertEqual(page.status_code, 200)
self.assertNotEqual(
api_key_before,
User.objects.get(pk=self.user.id).settings.api_key()
)
def test_user_settings_invalid(self): def test_user_settings_invalid(self):
self.c.login(**self.credentials) self.c.login(**self.credentials)

View File

@ -45,15 +45,6 @@ class ViewsTestCase(TestCase):
self.assertNotEqual(session1, session2) self.assertNotEqual(session1, session2)
self.assertEqual(session2, session3) self.assertEqual(session2, session3)
def test_user_reset_api_key(self):
api_key_before = User.objects.get(pk=self.user.id).settings.api_key()
page = self.c.get('/user/reset-api-key/')
self.assertEqual(page.status_code, 302)
self.assertNotEqual(
api_key_before,
User.objects.get(pk=self.user.id).settings.api_key()
)
def test_user_settings(self): def test_user_settings(self):
page = self.c.get('/user/settings/') page = self.c.get('/user/settings/')
self.assertEqual(page.status_code, 200) self.assertEqual(page.status_code, 200)

View File

@ -37,11 +37,6 @@ app_patterns = [
views.UserPassword.as_view(), views.UserPassword.as_view(),
name='user-password' name='user-password'
), ),
path(
'user/reset-api-key/',
views.UserResetAPIKey.as_view(),
name='user-reset-api-key'
),
path( path(
'user/settings/', 'user/settings/',
views.UserSettings.as_view(), views.UserSettings.as_view(),

View File

@ -103,16 +103,6 @@ class UserPassword(LoginRequiredMixin, View):
return render(request, self.template_name, {'form': form}) return render(request, self.template_name, {'form': form})
class UserResetAPIKey(LoginRequiredMixin, View):
"""
Resets the API key of the logged in user.
"""
def get(self, request):
request.user.settings.api_key(reset=True)
messages.success(request, _('User API key regenerated.'))
return redirect('babybuddy:user-settings')
class UserSettings(LoginRequiredMixin, View): class UserSettings(LoginRequiredMixin, View):
""" """
Handles both the User and Settings models. Handles both the User and Settings models.
@ -130,6 +120,11 @@ class UserSettings(LoginRequiredMixin, View):
}) })
def post(self, request): def post(self, request):
if request.POST.get('api_key_regenerate'):
request.user.settings.api_key(reset=True)
messages.success(request, _('User API key regenerated.'))
return redirect('babybuddy:user-settings')
form_user = self.form_user_class( form_user = self.form_user_class(
instance=request.user, instance=request.user,
data=request.POST) data=request.POST)