Update nouveau to only use the mmu interval notifiers.
Signed-off-by: Ralph Campbell rcampbell@nvidia.com --- drivers/gpu/drm/nouveau/nouveau_svm.c | 313 +++++++++++++++++--------- 1 file changed, 201 insertions(+), 112 deletions(-)
diff --git a/drivers/gpu/drm/nouveau/nouveau_svm.c b/drivers/gpu/drm/nouveau/nouveau_svm.c index df9bf1fd1bc0..0343e48d41d7 100644 --- a/drivers/gpu/drm/nouveau/nouveau_svm.c +++ b/drivers/gpu/drm/nouveau/nouveau_svm.c @@ -88,7 +88,7 @@ nouveau_ivmm_find(struct nouveau_svm *svm, u64 inst) }
struct nouveau_svmm { - struct mmu_notifier notifier; + struct mm_struct *mm; struct nouveau_vmm *vmm; struct { unsigned long start; @@ -98,6 +98,13 @@ struct nouveau_svmm { struct mutex mutex; };
+struct svmm_interval { + struct mmu_interval_notifier notifier; + struct nouveau_svmm *svmm; +}; + +static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops; + #define SVMM_DBG(s,f,a...) \ NV_DEBUG((s)->vmm->cli->drm, "svm-%p: "f"\n", (s), ##a) #define SVMM_ERR(s,f,a...) \ @@ -236,6 +243,8 @@ nouveau_svmm_join(struct nouveau_svmm *svmm, u64 inst) static void nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit) { + SVMM_DBG(svmm, "invalidate %016llx-%016llx", start, limit); + if (limit > start) { bool super = svmm->vmm->vmm.object.client->super; svmm->vmm->vmm.object.client->super = true; @@ -248,58 +257,25 @@ nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit) } }
-static int -nouveau_svmm_invalidate_range_start(struct mmu_notifier *mn, - const struct mmu_notifier_range *update) -{ - struct nouveau_svmm *svmm = - container_of(mn, struct nouveau_svmm, notifier); - unsigned long start = update->start; - unsigned long limit = update->end; - - if (!mmu_notifier_range_blockable(update)) - return -EAGAIN; - - SVMM_DBG(svmm, "invalidate %016lx-%016lx", start, limit); - - mutex_lock(&svmm->mutex); - if (unlikely(!svmm->vmm)) - goto out; - - if (limit > svmm->unmanaged.start && start < svmm->unmanaged.limit) { - if (start < svmm->unmanaged.start) { - nouveau_svmm_invalidate(svmm, start, - svmm->unmanaged.limit); - } - start = svmm->unmanaged.limit; - } - - nouveau_svmm_invalidate(svmm, start, limit); - -out: - mutex_unlock(&svmm->mutex); - return 0; -} - -static void nouveau_svmm_free_notifier(struct mmu_notifier *mn) -{ - kfree(container_of(mn, struct nouveau_svmm, notifier)); -} - -static const struct mmu_notifier_ops nouveau_mn_ops = { - .invalidate_range_start = nouveau_svmm_invalidate_range_start, - .free_notifier = nouveau_svmm_free_notifier, -}; - void nouveau_svmm_fini(struct nouveau_svmm **psvmm) { struct nouveau_svmm *svmm = *psvmm; + struct mmu_interval_notifier *mni; + if (svmm) { mutex_lock(&svmm->mutex); + while (true) { + mni = mmu_interval_notifier_find(svmm->mm, + &nouveau_svm_mni_ops, 0UL, ~0UL); + if (!mni) + break; + mmu_interval_notifier_put(mni); + } svmm->vmm = NULL; mutex_unlock(&svmm->mutex); - mmu_notifier_put(&svmm->notifier); + mmdrop(svmm->mm); + kfree(svmm); *psvmm = NULL; } } @@ -343,11 +319,12 @@ nouveau_svmm_init(struct drm_device *dev, void *data, goto out_free;
down_write(¤t->mm->mmap_sem); - svmm->notifier.ops = &nouveau_mn_ops; - ret = __mmu_notifier_register(&svmm->notifier, current->mm); + ret = __mmu_notifier_register(NULL, current->mm); if (ret) goto out_mm_unlock; - /* Note, ownership of svmm transfers to mmu_notifier */ + + mmgrab(current->mm); + svmm->mm = current->mm;
cli->svm.svmm = svmm; cli->svm.cli = cli; @@ -482,65 +459,212 @@ nouveau_svm_fault_cache(struct nouveau_svm *svm, fault->inst, fault->addr, fault->access); }
-struct svm_notifier { - struct mmu_interval_notifier notifier; - struct nouveau_svmm *svmm; -}; +static struct svmm_interval *nouveau_svmm_new_interval( + struct nouveau_svmm *svmm, + unsigned long start, + unsigned long last) +{ + struct svmm_interval *smi; + int ret; + + smi = kmalloc(sizeof(*smi), GFP_ATOMIC); + if (!smi) + return NULL; + + smi->svmm = svmm; + + ret = mmu_interval_notifier_insert_safe(&smi->notifier, svmm->mm, + start, last - start + 1, &nouveau_svm_mni_ops); + if (ret) { + kfree(smi); + return NULL; + } + + return smi; +} + +static void nouveau_svmm_do_unmap(struct mmu_interval_notifier *mni, + const struct mmu_notifier_range *range) +{ + struct svmm_interval *smi = + container_of(mni, struct svmm_interval, notifier); + struct nouveau_svmm *svmm = smi->svmm; + unsigned long start = mmu_interval_notifier_start(mni); + unsigned long last = mmu_interval_notifier_last(mni); + + if (start >= range->start) { + /* Remove the whole interval or keep the right-hand part. */ + if (last <= range->end) + mmu_interval_notifier_put(mni); + else + mmu_interval_notifier_update(mni, range->end, last); + return; + } + + /* Keep the left-hand part of the interval. */ + mmu_interval_notifier_update(mni, start, range->start - 1); + + /* If a hole is created, create an interval for the right-hand part. */ + if (last >= range->end) { + smi = nouveau_svmm_new_interval(svmm, range->end, last); + /* + * If we can't allocate an interval, we won't get invalidation + * callbacks so clear the mapping and rely on faults to reload + * the mappings if needed. + */ + if (!smi) + nouveau_svmm_invalidate(svmm, range->end, last + 1); + } +}
-static bool nouveau_svm_range_invalidate(struct mmu_interval_notifier *mni, - const struct mmu_notifier_range *range, - unsigned long cur_seq) +static bool nouveau_svmm_interval_invalidate(struct mmu_interval_notifier *mni, + const struct mmu_notifier_range *range, + unsigned long cur_seq) { - struct svm_notifier *sn = - container_of(mni, struct svm_notifier, notifier); + struct svmm_interval *smi = + container_of(mni, struct svmm_interval, notifier); + struct nouveau_svmm *svmm = smi->svmm;
/* - * serializes the update to mni->invalidate_seq done by caller and + * Serializes the update to mni->invalidate_seq done by the caller and * prevents invalidation of the PTE from progressing while HW is being - * programmed. This is very hacky and only works because the normal - * notifier that does invalidation is always called after the range - * notifier. + * programmed. */ if (mmu_notifier_range_blockable(range)) - mutex_lock(&sn->svmm->mutex); - else if (!mutex_trylock(&sn->svmm->mutex)) + mutex_lock(&svmm->mutex); + else if (!mutex_trylock(&svmm->mutex)) return false; + mmu_interval_set_seq(mni, cur_seq); - mutex_unlock(&sn->svmm->mutex); + nouveau_svmm_invalidate(svmm, range->start, range->end); + + /* Stop tracking the range if it is an unmap. */ + if (range->event == MMU_NOTIFY_UNMAP) + nouveau_svmm_do_unmap(mni, range); + + mutex_unlock(&svmm->mutex); return true; }
+static void nouveau_svmm_interval_release(struct mmu_interval_notifier *mni) +{ + struct svmm_interval *smi = + container_of(mni, struct svmm_interval, notifier); + + kfree(smi); +} + static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops = { - .invalidate = nouveau_svm_range_invalidate, + .invalidate = nouveau_svmm_interval_invalidate, + .release = nouveau_svmm_interval_release, };
+/* + * Find or create a mmu_interval_notifier for the given range. + * Although mmu_interval_notifier_insert_safe() can handle overlapping + * intervals, we only create non-overlapping intervals, shrinking the hmm_range + * if it spans more than one svmm_interval. + */ +static int nouveau_svmm_interval_find(struct nouveau_svmm *svmm, + struct hmm_range *range) +{ + struct mmu_interval_notifier *mni; + struct svmm_interval *smi; + struct vm_area_struct *vma; + unsigned long start = range->start; + unsigned long last = range->end - 1; + int ret; + + mutex_lock(&svmm->mutex); + mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops, start, + last); + if (mni) { + if (start >= mmu_interval_notifier_start(mni)) { + smi = container_of(mni, struct svmm_interval, notifier); + if (last > mmu_interval_notifier_last(mni)) + range->end = + mmu_interval_notifier_last(mni) + 1; + goto found; + } + WARN_ON(last <= mmu_interval_notifier_start(mni)); + range->end = mmu_interval_notifier_start(mni); + last = range->end - 1; + } + /* + * Might as well create an interval covering the underlying VMA to + * avoid having to create a bunch of small intervals. + */ + vma = find_vma(svmm->mm, range->start); + if (!vma || start < vma->vm_start) { + ret = -ENOENT; + goto err; + } + if (range->end > vma->vm_end) { + range->end = vma->vm_end; + last = range->end - 1; + } else if (!mni) { + /* Anything registered on the right part of the vma? */ + mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops, + range->end, vma->vm_end - 1); + if (mni) + last = mmu_interval_notifier_start(mni) - 1; + else + last = vma->vm_end - 1; + } + /* Anything registered on the left part of the vma? */ + mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops, + vma->vm_start, start - 1); + if (mni) + start = mmu_interval_notifier_last(mni) + 1; + else + start = vma->vm_start; + smi = nouveau_svmm_new_interval(svmm, start, last); + if (!smi) { + ret = -ENOMEM; + goto err; + } + +found: + range->notifier = &smi->notifier; + mutex_unlock(&svmm->mutex); + return 0; + +err: + mutex_unlock(&svmm->mutex); + return ret; +} + static int nouveau_range_fault(struct nouveau_svmm *svmm, struct nouveau_drm *drm, void *data, u32 size, - u64 *pfns, struct svm_notifier *notifier) + u64 *pfns, u64 start, u64 end) { unsigned long timeout = jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT); /* Have HMM fault pages within the fault window to the GPU. */ struct hmm_range range = { - .notifier = ¬ifier->notifier, - .start = notifier->notifier.interval_tree.start, - .end = notifier->notifier.interval_tree.last + 1, + .start = start, + .end = end, .pfns = pfns, .flags = nouveau_svm_pfn_flags, .values = nouveau_svm_pfn_values, + .default_flags = 0, + .pfn_flags_mask = ~0UL, .pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT, }; - struct mm_struct *mm = notifier->notifier.mm; + struct mm_struct *mm = svmm->mm; long ret;
while (true) { if (time_after(jiffies, timeout)) return -EBUSY;
- range.notifier_seq = mmu_interval_read_begin(range.notifier); - range.default_flags = 0; - range.pfn_flags_mask = -1UL; down_read(&mm->mmap_sem); + ret = nouveau_svmm_interval_find(svmm, &range); + if (ret) { + up_read(&mm->mmap_sem); + return ret; + } + range.notifier_seq = mmu_interval_read_begin(range.notifier); ret = hmm_range_fault(&range, 0); up_read(&mm->mmap_sem); if (ret <= 0) { @@ -585,7 +709,6 @@ nouveau_svm_fault(struct nvif_notify *notify) } i; u64 phys[16]; } args; - struct vm_area_struct *vma; u64 inst, start, limit; int fi, fn, pi, fill; int replay = 0, ret; @@ -640,7 +763,6 @@ nouveau_svm_fault(struct nvif_notify *notify) args.i.p.version = 0;
for (fi = 0; fn = fi + 1, fi < buffer->fault_nr; fi = fn) { - struct svm_notifier notifier; struct mm_struct *mm;
/* Cancel any faults from non-SVM channels. */ @@ -662,36 +784,12 @@ nouveau_svm_fault(struct nvif_notify *notify) start = max_t(u64, start, svmm->unmanaged.limit); SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit);
- mm = svmm->notifier.mm; + mm = svmm->mm; if (!mmget_not_zero(mm)) { nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]); continue; }
- /* Intersect fault window with the CPU VMA, cancelling - * the fault if the address is invalid. - */ - down_read(&mm->mmap_sem); - vma = find_vma_intersection(mm, start, limit); - if (!vma) { - SVMM_ERR(svmm, "wndw %016llx-%016llx", start, limit); - up_read(&mm->mmap_sem); - mmput(mm); - nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]); - continue; - } - start = max_t(u64, start, vma->vm_start); - limit = min_t(u64, limit, vma->vm_end); - up_read(&mm->mmap_sem); - SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit); - - if (buffer->fault[fi]->addr != start) { - SVMM_ERR(svmm, "addr %016llx", buffer->fault[fi]->addr); - mmput(mm); - nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]); - continue; - } - /* Prepare the GPU-side update of all pages within the * fault window, determining required pages and access * permissions based on pending faults. @@ -743,18 +841,9 @@ nouveau_svm_fault(struct nvif_notify *notify) args.i.p.addr, args.i.p.addr + args.i.p.size, fn - fi);
- notifier.svmm = svmm; - ret = mmu_interval_notifier_insert(¬ifier.notifier, - svmm->notifier.mm, - args.i.p.addr, args.i.p.size, - &nouveau_svm_mni_ops); - if (!ret) { - ret = nouveau_range_fault( - svmm, svm->drm, &args, - sizeof(args.i) + pi * sizeof(args.phys[0]), - args.phys, ¬ifier); - mmu_interval_notifier_remove(¬ifier.notifier); - } + ret = nouveau_range_fault(svmm, svm->drm, &args, + sizeof(args.i) + pi * sizeof(args.phys[0]), args.phys, + start, start + args.i.p.size); mmput(mm);
/* Cancel any faults in the window whose pages didn't manage