Krita Source Code Documentation
Loading...
Searching...
No Matches
kis_cubic_curve_spline.h
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: 2005 C. Boemann <cbo@boemann.dk>
3 * SPDX-FileCopyrightText: 2009 Dmitry Kazakov <dimula73@gmail.com>
4 * SPDX-FileCopyrightText: 2010 Cyrille Berger <cberger@cberger.net>
5 * SPDX-FileCopyrightText: 2024 Deif Lou <ginoba@gmail.com>
6 *
7 * SPDX-License-Identifier: GPL-2.0-or-later
8 */
9
10#ifndef _KIS_CUBIC_CURVE_SPLINE_H_
11#define _KIS_CUBIC_CURVE_SPLINE_H_
12
13#include <QVector>
14#include <QList>
15
16#include <Eigen/Sparse>
17
18#include <kis_assert.h>
19
20template <typename T>
22{
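/*
 * Solves a tridiagonal linear system A * x = f, where a[] is the
 * sub-diagonal, b[] the main diagonal and c[] the super-diagonal,
 * e.g. for a system of order 5:
 *
 *      |b0 c0  0  0  0|   |x0|   |f0|
 *      |a0 b1 c1  0  0|   |x1|   |f1|
 *      | 0 a1 b2 c2  0| * |x2| = |f2|
 *      | 0  0 a2 b3 c3|   |x3|   |f3|
 *      | 0  0  0 a3 b4|   |x4|   |f4|
 */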
31
32public:
33
37 static
39 QList<T> &b,
40 QList<T> &c,
41 QList<T> &f) {
42 QVector<T> x;
43 QVector<T> alpha;
44 QVector<T> beta;
45
46 int i;
47 int size = b.size();
48
49 Q_ASSERT(a.size() == size - 1 &&
50 c.size() == size - 1 &&
51 f.size() == size);
52
53 x.resize(size);
54
59 if (size == 1) {
60 x[0] = f[0] / b[0];
61 return x;
62 }
63
68 alpha.resize(size);
69 beta.resize(size);
70
71
72 alpha[1] = -c[0] / b[0];
73 beta[1] = f[0] / b[0];
74
75 for (i = 1; i < size - 1; i++) {
76 alpha[i+1] = -c[i] /
77 (a[i-1] * alpha[i] + b[i]);
78
79 beta[i+1] = (f[i] - a[i-1] * beta[i])
80 /
81 (a[i-1] * alpha[i] + b[i]);
82 }
83
84 x.last() = (f.last() - a.last() * beta.last())
85 /
86 (b.last() + a.last() * alpha.last());
87
88 for (i = size - 2; i >= 0; i--)
89 x[i] = alpha[i+1] * x[i+1] + beta[i+1];
90
91 return x;
92 }
93};
94
95template <typename T_point, typename T>
97{
108protected:
113
117 int m_intervals {0};
118
119public:
124
132 int intervals = m_intervals = a.size() - 1;
133 int i;
134 m_begin = a.first().x();
135 m_end = a.last().x();
136
137 m_a.clear();
138 m_b.resize(intervals);
139 m_c.clear();
140 m_d.resize(intervals);
141 m_h.resize(intervals);
142
143 for (i = 0; i < intervals; i++) {
144 m_h[i] = a[i+1].x() - a[i].x();
145 m_a.append(a[i].y());
146 }
147 m_a.append(a.last().y());
148
149
150 QList<T> tri_b;
151 QList<T> tri_f;
152 QList<T> tri_a; /* equals to @tri_c */
153
154 for (i = 0; i < intervals - 1; i++) {
155 tri_b.append(2.*(m_h[i] + m_h[i+1]));
156
157 tri_f.append(6.*((m_a[i+2] - m_a[i+1]) / m_h[i+1] - (m_a[i+1] - m_a[i]) / m_h[i]));
158 }
159 for (i = 1; i < intervals - 1; i++)
160 tri_a.append(m_h[i]);
161
162 if (intervals > 1) {
163 m_c = KisTridiagonalSystem<T>::calculate(tri_a, tri_b, tri_a, tri_f);
164 }
165 m_c.prepend(0);
166 m_c.append(0);
167
168 for (i = 0; i < intervals; i++)
169 m_d[i] = (m_c[i+1] - m_c[i]) / m_h[i];
170
171 for (i = 0; i < intervals; i++)
172 m_b[i] = -0.5 * (m_c[i] * m_h[i]) - (1 / 6.0) * (m_d[i] * m_h[i] * m_h[i]) + (m_a[i+1] - m_a[i]) / m_h[i];
173 }
174
178 T getValue(T x) const {
179 T x0;
180 int i = findRegion(x, x0);
181 /* TODO: check for asm equivalent */
182 return m_a[i] +
183 m_b[i] *(x - x0) +
184 0.5 * m_c[i] *(x - x0) *(x - x0) +
185 (1 / 6.0)* m_d[i] *(x - x0) *(x - x0) *(x - x0);
186 }
187
188 T begin() const {
189 return m_begin;
190 }
191
192 T end() const {
193 return m_end;
194 }
195
196protected:
197
203 int findRegion(T x, T &x0) const {
204 int i;
205 x0 = m_begin;
206 for (i = 0; i < m_intervals; i++) {
207 if (x >= x0 && x < x0 + m_h[i])
208 return i;
209 x0 += m_h[i];
210 }
211 if (x >= x0) {
212 x0 -= m_h[m_intervals-1];
213 return m_intervals - 1;
214 }
215
216 qDebug("X value: %f\n", x);
217 qDebug("m_begin: %f\n", m_begin);
218 qDebug("m_end : %f\n", m_end);
219 Q_ASSERT_X(0, "findRegion", "X value is outside regions");
220 /* **never reached** */
221 return -1;
222 }
223};
224
225template <typename T_point, typename T>
227{
228public:
231 createSpline(a);
232 }
233
241 KIS_SAFE_ASSERT_RECOVER_RETURN(a.size() > 0);
242
243 const int intervals = a.size() - 1;
244 m_points = a;
245
246 m_coefficients.clear();
247
248 if (a.size() == 1) {
249 // Constant function
250 m_coefficients.append({ 0.0, 0.0, 0.0, a.first().y() });
251 return;
252 }
253
254 if (a.size() == 2) {
255 // Linear function
256 const T c = (a.last().y() - a.first().y()) / (a.last().x() - a.first().x());
257 const T d = a.first().y() - c * a.first().x();
258 m_coefficients.append({ 0.0, 0.0, c, d });
259 return;
260 }
261
262 using Triplet = Eigen::Triplet<qreal>;
263 using Matrix = Eigen::SparseMatrix<qreal>;
264 using Vector = Eigen::VectorXd;
265
266 const int numberOfRows = intervals * 4;
267 const int numberOfColumns = numberOfRows;
268 std::vector<Triplet> triplets;
269 Matrix A(numberOfRows, numberOfColumns);
270 Vector b(numberOfRows);
271
272 // Fill the triplet list
273 triplets.reserve(numberOfRows * 4);
274 qint32 row = 0;
275 // Fill rows with position equations
276 // Initialize the values for the left point of the first interval. The
277 // rest of the left points of the intervals use the values computed for
278 // the right point of the previous interval
279 T pointX = a.first().x();
280 T pointY = a.first().y();
281 T pointXSquared = pointX * pointX;
282 T pointXCubed = pointXSquared * pointX;
283 for (qint32 i = 0; i < intervals; ++i) {
284 const int baseColumn = i * 4;
285 // Left point
286 triplets.push_back(Triplet(row, baseColumn + 0, pointXCubed));
287 triplets.push_back(Triplet(row, baseColumn + 1, pointXSquared));
288 triplets.push_back(Triplet(row, baseColumn + 2, pointX));
289 triplets.push_back(Triplet(row, baseColumn + 3, 1.0));
290 b(row) = pointY;
291 ++row;
292 // Right point (the following values are reused for the left point
293 // of the next interval)
294 pointX = a[i + 1].x();
295 pointY = a[i + 1].y();
296 pointXSquared = pointX * pointX;
297 pointXCubed = pointXSquared * pointX;
298 triplets.push_back(Triplet(row, baseColumn + 0, pointXCubed));
299 triplets.push_back(Triplet(row, baseColumn + 1, pointXSquared));
300 triplets.push_back(Triplet(row, baseColumn + 2, pointX));
301 triplets.push_back(Triplet(row, baseColumn + 3, 1.0));
302 b(row) = pointY;
303 ++row;
304 }
305 // Fill rows with derivative equations
306 // Extreme knots second derivatives
307 pointX = a.first().x();
308 triplets.push_back(Triplet(row, 0, 6.0 * pointX));
309 triplets.push_back(Triplet(row, 1, 2.0));
310 b(row) = 0.0;
311 ++row;
312 pointX = a.last().x();
313 triplets.push_back(Triplet(row, numberOfColumns - 4, 6.0 * pointX));
314 triplets.push_back(Triplet(row, numberOfColumns - 3, 2.0));
315 b(row) = 0.0;
316 ++row;
317 // Interior knots derivatives
318 for (qint32 i = 1; i < a.size() - 1; ++i) {
319 pointX = a[i].x();
320 const qint32 baseColumn = i * 4;
321 if (a[i].isSetAsCorner()) {
322 triplets.push_back(Triplet(row, baseColumn - 4, 6.0 * pointX));
323 triplets.push_back(Triplet(row, baseColumn - 3, 2.0));
324 b(row) = 0.0;
325 ++row;
326 triplets.push_back(Triplet(row, baseColumn + 0, 6.0 * pointX));
327 triplets.push_back(Triplet(row, baseColumn + 1, 2.0));
328 b(row) = 0.0;
329 ++row;
330 } else {
331 pointXSquared = pointX * pointX;
332 // First derivatives
333 triplets.push_back(Triplet(row, baseColumn - 4, 3.0 * pointXSquared));
334 triplets.push_back(Triplet(row, baseColumn - 3, 2.0 * pointX));
335 triplets.push_back(Triplet(row, baseColumn - 2, 1.0));
336 triplets.push_back(Triplet(row, baseColumn + 0, -3.0 * pointXSquared));
337 triplets.push_back(Triplet(row, baseColumn + 1, -2.0 * pointX));
338 triplets.push_back(Triplet(row, baseColumn + 2, -1.0));
339 b(row) = 0.0;
340 ++row;
341 // Second derivatives
342 triplets.push_back(Triplet(row, baseColumn - 4, 6.0 * pointX));
343 triplets.push_back(Triplet(row, baseColumn - 3, 2.0));
344 triplets.push_back(Triplet(row, baseColumn + 0, -6.0 * pointX));
345 triplets.push_back(Triplet(row, baseColumn + 1, -2.0));
346 b(row) = 0.0;
347 ++row;
348 }
349 }
350 // Solve
351 A.setFromTriplets(triplets.begin(), triplets.end());
352 Eigen::SparseLU<Matrix> solver(A);
353 Vector x = solver.solve(b);
354 // Fill coefficients
355 for (qint32 i = 0; i < intervals; ++i) {
356 row = i * 4;
357 m_coefficients.append({x(row), x(row + 1), x(row + 2), x(row + 3)});
358 }
359 }
360
364 T getValue(T x) const {
366 // Find the interval for the given x value
367 int interval;
368 for (interval = 0; interval < m_coefficients.size() - 1; ++interval) {
369 if (x < m_points[interval + 1].x()) {
370 break;
371 }
372 }
373 // Evaluate
374 const T xSquared = x * x;
375 const T xCubed = xSquared * x;
376 const Coefficients& coefficients = m_coefficients[interval];
377 return coefficients.a * xCubed + coefficients.b * xSquared +
378 coefficients.c * x + coefficients.d;
379 }
380
381private:
386 {
387 T a, b, c, d;
388 };
389
392};
393
394#endif
KisCubicSpline(const QList< T_point > &a)
void createSpline(const QList< T_point > &a)
QList< T_point > m_points
QList< Coefficients > m_coefficients
KisLegacyCubicSpline(const QList< T_point > &a)
int findRegion(T x, T &x0) const
void createSpline(const QList< T_point > &a)
static QVector< T > calculate(QList< T > &a, QList< T > &b, QList< T > &c, QList< T > &f)
#define KIS_SAFE_ASSERT_RECOVER_RETURN_VALUE(cond, val)
Definition kis_assert.h:129
#define KIS_SAFE_ASSERT_RECOVER_RETURN(cond)
Definition kis_assert.h:128