diff --git a/api/views.py b/api/views.py index ca7b1ca2..1af49d46 100644 --- a/api/views.py +++ b/api/views.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from rest_framework import viewsets from core.models import Baby, DiaperChange, Feeding, Sleep, TummyTime +from core.utils import filter_by_params from .serializers import (BabySerializer, DiaperChangeSerializer, FeedingSerializer, SleepSerializer, TummyTimeSerializer,) @@ -14,14 +15,8 @@ class BabyViewSet(viewsets.ModelViewSet): serializer_class = BabySerializer def get_queryset(self): - queryset = Baby.objects.all() - - for param in ['first_name', 'last_name']: - value = self.request.query_params.get(param, None) - if value is not None: - queryset = queryset.filter(**{param: value}) - - return queryset + params = ['first_name', 'last_name'] + return filter_by_params(self.request, Baby, params) class DiaperChangeViewSet(viewsets.ModelViewSet): @@ -29,14 +24,8 @@ class DiaperChangeViewSet(viewsets.ModelViewSet): serializer_class = DiaperChangeSerializer def get_queryset(self): - queryset = DiaperChange.objects.all() - - for param in ['baby__last_name', 'wet', 'solid', 'color']: - value = self.request.query_params.get(param, None) - if value is not None: - queryset = queryset.filter(**{param: value}) - - return queryset + params = ['baby__last_name', 'wet', 'solid', 'color'] + return filter_by_params(self.request, DiaperChange, params) class FeedingViewSet(viewsets.ModelViewSet): @@ -44,14 +33,8 @@ class FeedingViewSet(viewsets.ModelViewSet): serializer_class = FeedingSerializer def get_queryset(self): - queryset = Feeding.objects.all() - - for param in ['baby__last_name', 'type', 'method']: - value = self.request.query_params.get(param, None) - if value is not None: - queryset = queryset.filter(**{param: value}) - - return queryset + params = ['baby__last_name', 'type', 'method'] + return filter_by_params(self.request, Feeding, params) class SleepViewSet(viewsets.ModelViewSet): @@ -59,14 +42,8 @@ class SleepViewSet(viewsets.ModelViewSet): serializer_class = SleepSerializer def get_queryset(self): - queryset = Sleep.objects.all() - - for param in ['baby__last_name']: - value = self.request.query_params.get(param, None) - if value is not None: - queryset = queryset.filter(**{param: value}) - - return queryset + params = ['baby__last_name'] + return filter_by_params(self.request, Sleep, params) class TummyTimeViewSet(viewsets.ModelViewSet): @@ -74,11 +51,5 @@ class TummyTimeViewSet(viewsets.ModelViewSet): serializer_class = TummyTimeSerializer def get_queryset(self): - queryset = TummyTime.objects.all() - - for param in ['baby__last_name']: - value = self.request.query_params.get(param, None) - if value is not None: - queryset = queryset.filter(**{param: value}) - - return queryset + params = ['baby__last_name'] + return filter_by_params(self.request, TummyTime, params) diff --git a/core/utils.py b/core/utils.py index 69ac464e..7d89d278 100644 --- a/core/utils.py +++ b/core/utils.py @@ -21,3 +21,14 @@ def duration_string(start, end): '' if duration is '' else ', ', s, 's' if s > 1 else '') return duration + + +def filter_by_params(request, model, available_params): + queryset = model.objects.all() + + for param in available_params: + value = request.query_params.get(param, None) + if value is not None: + queryset = queryset.filter(**{param: value}) + + return queryset