Default to current user on timer POST (#134)

This commit is contained in:
Christopher C. Wells 2020-06-19 14:23:43 -07:00
parent d6c2b48917
commit 4e5153daed
2 changed files with 19 additions and 1 deletions

View File

@ -135,13 +135,24 @@ class TimerSerializer(CoreModelSerializer):
child = serializers.PrimaryKeyRelatedField( child = serializers.PrimaryKeyRelatedField(
allow_null=True, allow_empty=True, queryset=models.Child.objects.all(), allow_null=True, allow_empty=True, queryset=models.Child.objects.all(),
required=False) required=False)
user = serializers.PrimaryKeyRelatedField(queryset=User.objects.all()) user = serializers.PrimaryKeyRelatedField(
allow_null=True, allow_empty=True, queryset=User.objects.all(),
required=False)
class Meta: class Meta:
model = models.Timer model = models.Timer
fields = ('id', 'child', 'name', 'start', 'end', 'duration', 'active', fields = ('id', 'child', 'name', 'start', 'end', 'duration', 'active',
'user') 'user')
def validate(self, attrs):
attrs = super(TimerSerializer, self).validate(attrs)
# Set user to current user if no value is provided.
if 'user' not in attrs or attrs['user'] is None:
attrs['user'] = self.context['request'].user
return attrs
class TummyTimeSerializer(CoreModelWithDurationSerializer): class TummyTimeSerializer(CoreModelWithDurationSerializer):
class Meta(CoreModelWithDurationSerializer.Meta): class Meta(CoreModelWithDurationSerializer.Meta):

View File

@ -368,6 +368,13 @@ class TimerAPITestCase(TestBase.BabyBuddyAPITestCaseBase):
obj = models.Timer.objects.get(pk=response.data['id']) obj = models.Timer.objects.get(pk=response.data['id'])
self.assertEqual(obj.name, data['name']) self.assertEqual(obj.name, data['name'])
def test_post_default_user(self):
user = User.objects.first()
response = self.client.post(self.endpoint)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
obj = models.Timer.objects.get(pk=response.data['id'])
self.assertEqual(obj.user, user)
def test_patch(self): def test_patch(self):
endpoint = '{}{}/'.format(self.endpoint, 1) endpoint = '{}{}/'.format(self.endpoint, 1)
response = self.client.get(endpoint) response = self.client.get(endpoint)