Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions contentcuration/contentcuration/db/models/expressions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from django.db.models import BooleanField
from django.db.models import F
from django.db.models import Q
from django.db.models import Value
from django.db.models.expressions import CombinedExpression
from django.db.models.expressions import Func
from django.db.models.sql.where import WhereNode
Expand Down Expand Up @@ -28,6 +30,20 @@ class BooleanComparison(CombinedExpression):
output_field = BooleanField()


class IsNull(BooleanComparison):
    """
    Boolean-valued expression for annotating whether a column IS NULL
    or IS NOT NULL.

    Example:
        IsNull('my_field_name') -> my_field_name IS NULL
        IsNull('my_field_name', negate=True) -> my_field_name IS NOT NULL

    :param field_name: Name of the field to test; it is wrapped in an `F()` reference
    :param negate: When truthy the `IS NOT` operator is used instead of `IS`
    """
    def __init__(self, field_name, negate=False):
        if negate:
            comparison = 'IS NOT'
        else:
            comparison = 'IS'

        # left-hand side is the column reference, right-hand side the NULL literal
        column = F(field_name)
        null_value = Value(None)
        super(IsNull, self).__init__(column, comparison, null_value)


class Array(Func):
"""
Create an array datatype within Postgres.
Expand Down
104 changes: 104 additions & 0 deletions contentcuration/contentcuration/db/models/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,110 @@
from django.db import connections
from django.db.models.expressions import Col
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.constants import INNER
from django.db.models.sql.query import Query
from django_cte import CTEQuerySet
from django_cte import With as CTEWith
from mptt.querysets import TreeQuerySet


# Join type string that callers pass as `_join_type` to `With.join()`, which assigns
# it as the `join_type` of the CTE's join in the resulting query
RIGHT_JOIN = 'RIGHT JOIN'


class CustomTreeQuerySet(TreeQuerySet, CTEQuerySet):
    """
    Queryset combining `TreeQuerySet` (mptt) with `CTEQuerySet` (django_cte);
    adds no behaviour of its own.
    """


class With(CTEWith):
    """
    CTE class that supports join types other than INNER and LOUTER (LEFT)
    """
    def join(self, model_or_queryset, *filter_q, **filter_kw):
        """
        Join this CTE to a model or queryset, honouring a custom `_join_type`.

        The `_join_type` keyword (default: INNER) is read here without being removed,
        so it is also forwarded to the parent implementation along with everything else.

        :param model_or_queryset: Passed through to the parent `join()`
        :param filter_q: Positional arguments passed through to the parent `join()`
        :param filter_kw: Keyword arguments passed through to the parent `join()`
        :return: The queryset produced by the parent `join()`, with the join type overridden
        """
        requested_type = filter_kw.get('_join_type', INNER)
        joined_qs = super(With, self).join(model_or_queryset, *filter_q, **filter_kw)

        # Slight hack: the underlying Django code forces the join type into INNER or a
        # LEFT OUTER join, so look up the join registered for this CTE and overwrite it
        query = joined_qs.query
        cte_alias, _ = query.table_alias(self.name)
        cte_join = query.alias_map[cte_alias]
        if cte_join.join_type != requested_type:
            cte_join.join_type = requested_type

        return joined_qs


class WithValues(With):
    """
    A CTE whose contents come from a literal VALUES list

    @see https://www.postgresql.org/docs/9.6/queries-values.html
    """
    def __init__(self, fields, values_list, name="cte"):
        """
        :param fields: Mapping of column name to the field instance describing that column
        :param values_list: The rows of values that make up the VALUES list
        :param name: Name given to the CTE
        """
        super(WithValues, self).__init__(None, name=name)
        self.fields = fields
        self.values_list = values_list
        # swap in the dedicated query object that knows how to compile the VALUES list
        self.query = WithValuesQuery(self)

    def _resolve_ref(self, name):
        """
        Invoked when a column reference is accessed through the CTE instance `.col.name`

        :param name: The column name being referenced
        :return: A `Col` for that column, aliased to this CTE
        :raises RuntimeError: If `name` isn't one of the configured fields
        """
        known_fields = self.fields
        if name not in known_fields:
            raise RuntimeError("No field with name `{}`".format(name))

        column_field = known_fields[name]
        column_field.set_attributes_from_name(name)
        return Col(self.name, column_field, output_field=column_field)


class WithValuesSQLCompiler(SQLCompiler):
    """
    Compiler that renders a `WithValues` CTE as a parameterized VALUES list
    """
    TEMPLATE = "SELECT * FROM (VALUES {values_statement}) AS {cte_name}({fields_statement})"

    def as_sql(self, with_limits=True, with_col_aliases=False):
        """
        Ideally this would return something like:
            WITH t_cte(fieldA, fieldB) AS (VALUES (), ...)
        But django_cte doesn't give us a way to do that, so we do this instead:
            WITH t_cte AS (SELECT * FROM (VALUES (), ...) AS _t_cte(fieldA, fieldB))

        Each row contributes one `%s` placeholder per field, and the parameters are the
        rows flattened in order. Rows may be any iterable (tuple, list, ...).

        NOTE: an empty `values_list` renders `VALUES` with no rows; callers should pass
        at least one row.

        :param with_limits: Unused, present for signature compatibility
        :param with_col_aliases: Unused, present for signature compatibility
        :return: A tuple of SQL and parameters
        """
        # local import so this block doesn't depend on changes to the module's imports
        from itertools import chain

        cte = self.cte
        row_placeholder = "({})".format(", ".join(["%s"] * len(cte.fields)))
        values_statement = ", ".join([row_placeholder] * len(cte.values_list))

        quote_name = self.connection.ops.quote_name
        fields_statement = ", ".join(quote_name(field) for field in cte.fields)

        sql = self.TEMPLATE.format(
            values_statement=values_statement,
            cte_name="_{}".format(cte.name),
            fields_statement=fields_statement
        )
        # flatten in linear time; `sum(rows, ())` re-copies the accumulated tuple for
        # every row and fails on rows that aren't tuples
        params = list(chain.from_iterable(cte.values_list))
        return sql, params

    @property
    def cte(self):
        """
        :rtype: WithValues
        """
        return self.query.cte


class WithValuesQuery(Query):
    """
    Query class dedicated to building a VALUES CTE

    Note: this inherits from Query but is constructed without a Model, so not all Query
    functionality is intended to work
    """
    def __init__(self, cte):
        """
        :param cte: The CTE instance this query belongs to, kept on `self.cte`
        """
        super(WithValuesQuery, self).__init__(None)
        self.cte = cte

    def get_compiler(self, using=None, connection=None):
        """
        Modeled after Query.get_compiler(), but always builds a WithValuesSQLCompiler

        :param using: Database alias; when truthy it takes precedence over `connection`
        :param connection: Database connection, used when `using` isn't given
        :return: A WithValuesSQLCompiler bound to this query
        :raises ValueError: If neither `using` nor `connection` is provided
        """
        if using is None and connection is None:
            raise ValueError("Need either using or connection")

        resolved_connection = connections[using] if using else connection
        return WithValuesSQLCompiler(self, resolved_connection, using)
1 change: 1 addition & 0 deletions contentcuration/contentcuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ class ContentTag(models.Model):
id = UUIDField(primary_key=True, default=uuid.uuid4)
tag_name = models.CharField(max_length=50)
channel = models.ForeignKey('Channel', related_name='tags', blank=True, null=True, db_index=True, on_delete=models.SET_NULL)
objects = CustomManager()

def __str__(self):
return self.tag_name
Expand Down
129 changes: 84 additions & 45 deletions contentcuration/contentcuration/viewsets/contentnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.conf import settings
from django.db import IntegrityError
from django.db import models
from django.db.models import Exists
from django.db.models import F
from django.db.models import OuterRef
Expand All @@ -12,6 +13,7 @@
from django.db.models.functions import Coalesce
from django.http import Http404
from django.utils.timezone import now
from django_cte import CTEQuerySet
from django_filters.rest_framework import CharFilter
from django_filters.rest_framework import DjangoFilterBackend
from django_filters.rest_framework import UUIDFilter
Expand All @@ -28,13 +30,18 @@
from rest_framework.serializers import ValidationError
from rest_framework.viewsets import ViewSet

from contentcuration.db.models.expressions import IsNull
from contentcuration.db.models.query import RIGHT_JOIN
from contentcuration.db.models.query import With
from contentcuration.db.models.query import WithValues
from contentcuration.models import AssessmentItem
from contentcuration.models import Channel
from contentcuration.models import ContentNode
from contentcuration.models import ContentTag
from contentcuration.models import File
from contentcuration.models import generate_storage_url
from contentcuration.models import PrerequisiteContentRelationship
from contentcuration.models import UUIDField
from contentcuration.tasks import create_async_task
from contentcuration.viewsets.base import BulkListSerializer
from contentcuration.viewsets.base import BulkModelSerializer
Expand Down Expand Up @@ -115,60 +122,92 @@ def filter__node_id_channel_id(self, queryset, name, value):
return queryset.filter(query)


def bulk_create_tag_relations(tags_relations_to_create):
    """
    Persist a batch of node <-> tag relation objects.

    Does nothing when given an empty (falsy) collection. Otherwise tries a single bulk
    insert first and, if that raises an IntegrityError, saves each relation individually.

    :param tags_relations_to_create: Instances of the `ContentNode.tags.through` model
    """
    if not tags_relations_to_create:
        return

    # In Django 2.2 add ignore_conflicts to make this fool proof
    try:
        ContentNode.tags.through.objects.bulk_create(tags_relations_to_create)
    except IntegrityError:
        # At least one relation already exists, so fall back to saving one at a time.
        # Django's default upsert behaviour should mean we get no errors this way
        for relation in tags_relations_to_create:
            relation.save()
# Column definitions for the VALUES CTE built in `set_tags`: maps each CTE column name
# to the field instance describing it. The rows handed to the CTE are positional
# `(tag_name, node_id)` tuples while the column names are taken by iterating this
# mapping, so the key order here needs to line up with the tuple order.
tags_values_cte_fields = {
    'tag': models.CharField(),
    'node_id': UUIDField()
}


def set_tags(tags_by_id):
all_tag_names = set()
tags_relations_to_create = []
tag_tuples = []
tags_relations_to_delete = []

# put all tags into a tuple (tag_name, node_id) to send into SQL
for target_node_id, tag_names in tags_by_id.items():
for tag_name, value in tag_names.items():
if value:
all_tag_names.add(tag_name)

# channel is no longer used on the tag object, so don't bother using it
available_tags = set(
ContentTag.objects.filter(
tag_name__in=all_tag_names, channel__isnull=True
).values_list("tag_name", flat=True)
tag_tuples.append((tag_name, target_node_id))

# create CTE that holds the tag_tuples data
values_cte = WithValues(tags_values_cte_fields, tag_tuples, name='values_cte')

# create another CTE which will RIGHT join against the tag table, so we get all of our
# tag_tuple data back, plus the tag_id if it exists. Ideally we wouldn't normally use a RIGHT
# join, we would simply swap the tables and do a LEFT, but with the VALUES CTE
# that isn't possible
tags_qs = (
values_cte.join(ContentTag, tag_name=values_cte.col.tag, _join_type=RIGHT_JOIN)
.annotate(
tag=values_cte.col.tag,
node_id=values_cte.col.node_id,
tag_id=F('id'),
)
.values('tag', 'node_id', 'tag_id')
)
tags_cte = With(tags_qs, name='tags_cte')

# the final query, we RIGHT join against the tag relation table so we get the tag_tuple back
# again, plus the tag_id from the previous CTE, plus annotate a boolean of whether
# the relation exists
qs = (
tags_cte.join(
CTEQuerySet(model=ContentNode.tags.through),
contenttag_id=tags_cte.col.tag_id,
contentnode_id=tags_cte.col.node_id,
_join_type=RIGHT_JOIN
)
.with_cte(values_cte)
.with_cte(tags_cte)
.annotate(
tag_name=tags_cte.col.tag,
node_id=tags_cte.col.node_id,
tag_id=tags_cte.col.tag_id,
has_relation=IsNull('contentnode_id', negate=True)
)
.values('tag_name', 'node_id', 'tag_id', 'has_relation')
)

tags_to_create = all_tag_names.difference(available_tags)

new_tags = [ContentTag(tag_name=tag_name) for tag_name in tags_to_create]
ContentTag.objects.bulk_create(new_tags)

tag_id_by_tag_name = {
t["tag_name"]: t["id"]
for t in ContentTag.objects.filter(
tag_name__in=all_tag_names, channel__isnull=True
).values("tag_name", "id")
}

for target_node_id, tag_names in tags_by_id.items():
for tag_name, value in tag_names.items():
if value:
tag_id = tag_id_by_tag_name[tag_name]
tags_relations_to_create.append(
ContentNode.tags.through(
contentnode_id=target_node_id, contenttag_id=tag_id
)
)
created_tags = {}
for result in qs:
tag_name = result["tag_name"]
node_id = result["node_id"]
tag_id = result["tag_id"]
has_relation = result["has_relation"]

tags = tags_by_id[node_id]
value = tags[tag_name]

# tag wasn't found in the DB, but we're adding it to the node, so create it
if not tag_id and value:
# keep a cache of created tags during the session
if tag_name in created_tags:
tag_id = created_tags[tag_name]
else:
tags_relations_to_delete.append(
Q(contentnode_id=target_node_id, contenttag__tag_name=tag_name)
)
bulk_create_tag_relations(tags_relations_to_create)
tag, _ = ContentTag.objects.get_or_create(tag_name=tag_name, channel_id=None)
tag_id = tag.pk
created_tags.update({tag_name: tag_id})

# if we're adding the tag but the relation didn't exist, create it now, otherwise
# track the tag as one relation we should delete
if value and not has_relation:
ContentNode.tags.through.objects.get_or_create(
contentnode_id=node_id, contenttag_id=tag_id
)
elif not value and has_relation:
tags_relations_to_delete.append(
Q(contentnode_id=node_id, contenttag_id=tag_id)
)

# delete tags
if tags_relations_to_delete:
ContentNode.tags.through.objects.filter(
reduce(lambda x, y: x | y, tags_relations_to_delete)
Expand Down