diff --git a/addons/libs/strings.lua b/addons/libs/strings.lua index 604d18654..4697a2095 100644 --- a/addons/libs/strings.lua +++ b/addons/libs/strings.lua @@ -14,11 +14,14 @@ local string = require('string') _libs.strings = string +_raw = _raw or {} +_raw.error = _raw.error or error + _meta = _meta or {} debug.setmetatable('', { __index = function(str, k) - return string[k] or type(k) == 'number' and string.sub(str, k, k) or (_raw and _raw.error or error)('"%s" is not defined for strings':format(tostring(k)), 2) + return string[k] or type(k) == 'number' and string.sub(str, k, k) or _raw.error('"%s" is not defined for strings':format(tostring(k)), 2) end, __unm = functions.negate .. functions.equals, __unp = functions.equals, @@ -130,54 +133,659 @@ end -- Returns an iterator, that goes over every character of the string. Handles Japanese text as well as special characters and auto-translate. do - local process = function(str, fn) - local index = 1 - return function() - if index > #str then + local adjust_from = function(str, index) + return + not index and 1 or + index < 0 and #str - index + 1 or + index == 0 and 1 or + index + end + + local adjust_to = function(str, index) + return + not index and #str or + index < 0 and #str - index + 1 or + index + end + + do + local process = function(str, from, to, fn) + local index = from + return function() + if index > to then + return nil + end + + local length = fn(str:byte(index, index)) + if length == nil then + _raw.error('Invalid code point') + end + + index = index + length + return str:sub(index - length, index - 1) + end + end + + local iterators = { + [string.encoding.ascii] = function(str, from, to) + return str:sub(from, to):gmatch('.') + end, + [string.encoding.utf8] = function(str, from, to) + return process(str, from, to, function(byte) + return + byte < 0x80 and 1 or + byte < 0xE0 and 2 or + byte < 0xF0 and 3 or + byte < 0xF8 and 4 + end) + end, + [string.encoding.shift_jis] = function(str, from, to) + return process(str, from, to, function(byte) + return + (byte < 0x80 or byte >= 0xA1 and byte <= 0xDF) and 1 or + (byte >= 0x80 and byte <= 0x9F or byte >= 0xE0 and byte <= 0xEF or byte >= 0xFA and byte <= 0xFC) and 2 or + byte == 0xFD and 6 + end) + end, + [string.encoding.binary] = function(str, from, to) + return str:sub(from, to):gmatch('.') + end, + } + + function string.it(str, encoding, from, to) + if type(encoding) ~= 'table' then + encoding, from, to = string.encoding.ascii, encoding, from + end + return iterators[encoding](str, from or 1, to or #str) + end + end + + local all = function() return true end + local lower = function(b) return b >= 0x61 and b <= 0x7A end + local upper = function(b) return b >= 0x41 and b <= 0x5A end + local control = function(b) return b == 0x7F or b < 0x20 end + local punctuation = function(b) return b >= 0x21 and b <= 0x2F or b >= 0x3A and b <= 0x40 or b >= 0x5B and b <= 0x60 or b >= 0x7B and b <= 0x7E end + local letter = function(b) return lower(b) or upper(b) end + local digit = function(b) return b >= 0x30 and b <= 0x39 end + local space = function(b) return b == 0x20 or b == 0x0A or b == 0x07 or b == 0x09 end + local hex = function(b) return b >= 0x30 and b <= 0x39 or b >= 0x41 and b <= 0x46 or b >= 0x63 and b <= 0x66 end + local zero = function(b) return b == 0x00 end + local utf8_letter = function(b) return not control(b) and not digit(b) and not punctuation(b) and not space(b) end + local shift_jis_letter = function(b) return letter(b) or b >= 0xA1 and b <= 0xDF or b >= 0x823F and b <= 0x8491 or b >= 0x889F and b <= 0x9872 or b >= 0x989F and b <= 0xEAA4 end + local none = function() return false end + local classes = { + [string.encoding.ascii] = { + a = letter, + c = control, + d = digit, + l = lower, + p = punctuation, + u = upper, + s = space, + w = function(b) return letter(b) or digit(b) end, + x = hex, + z = zero, + }, + [string.encoding.utf8] = { + a = utf8_letter, + c = control, + d = digit, + l = lower, + p = punctuation, + u = upper, + s = space, + w = function(b) return utf8_letter(b) or digit(b) end, + x = hex, + z = zero, + }, + [string.encoding.shift_jis] = { + a = shift_jis_letter, + c = control, + d = digit, + l = lower, + p = function(b) return punctuation(b) or b >= 0x8140 and b <= 0x81FC or b >= 0x849F and b <= 0x84BE or b >= 0x8740 and b <= 0x849C end, + u = upper, + s = space, + w = function(b) return shift_jis_letter(b) or digit(b) end, + x = hex, + z = zero, + }, + [string.encoding.binary] = { + a = none, + c = none, + d = none, + l = none, + p = none, + u = none, + s = none, + w = none, + x = none, + z = zero, + }, + } + + do + local rawfind = string.find + + local findplain = function(str, pattern, encoding, from, to) + local offset = #pattern - 1 + local index = from + local search = pattern:it(encoding):pack() + local length = #search + for c in str:it(encoding, from, to - offset + 1) do + local position = 0 + for check in str:it(encoding, index + #c, index + offset) do + position = position + 1 + if check ~= search[position] then + break + end + if position == length then + return index, index + offset + end + end + index = index + #c + end + return nil, nil + end + + local types = { + capture = {}, + fixed = {}, + match = {}, + boundary = {}, + balanced = {}, + } + + local bytes = function(c) + local value = 0 + for i = 1, #c do + value = value * 0x100 + string.byte(c, i, i) + end + return value + end + + local parse + parse = function(iterator, class_lookup, level) + level = level or 0 + local pattern = {} + + local count = 0 + local last = false + local previous_char = nil + for c in iterator do + if last then + return '$ only valid at the end' + end + + local previous = pattern[count] + count = count + 1 + if (c == '-' or c == '?' or c == '*' or c == '+') and previous ~= nil and (previous.type == types.match or previous.type == types.fixed) and not previous.counter then + if previous.type == types.fixed then + if previous.value == previous_char then + count = count - 1 + else + previous.value = previous.value:sub(1, #previous.value - #previous_char) + end + local compare = bytes(c) + pattern[count] = { + type = types.match, + value = function(b) return b == compare end, + } + else + count = count - 1 + end + pattern[count].counter = c + elseif c == '(' then + pattern[count] = { + type = types.capture, + value = parse(iterator, class_lookup, level + 1), + } + elseif c == '.' then + pattern[count] = { + type = types.match, + value = all, + } + elseif c == '[' then + local set = {} + local next = iterator() + while next ~= ']' do + local add = next == '%' and iterator() or next + if add == nil then + return 'missing \']\'' + end + set[bytes(add)] = true + next = iterator() + end + + pattern[count] = { + type = types.match, + value = function(b) return set[b] end, + } + elseif c == '^' then + if count > 1 then + return '^ only valid at the start' + end + + pattern[count] = { + type = types.boundary, + value = '^', + } + elseif c == '$' then + last = true + + pattern[count] = { + type = types.boundary, + value = '$', + } + else + local single = nil + if c == '%' then + local next = iterator() + if next == nil then + return 'ends with \'%\'' + end + + if next == 'b' then + local open = iterator() + if open == nil then + return 'unbalanced pattern' + end + local close = iterator() + if close == nil then + return 'unbalanced pattern' + end + pattern[count] = { + type = types.balanced, + value = {open, close}, + } + else + local fn = class_lookup[next] + if fn ~= nil then + pattern[count] = { + type = types.match, + value = fn, + } + else + single = next + end + end + elseif c == ')' and level > 0 then + break + else + single = c + end + + if single ~= nil then + if previous ~= nil and previous.type == types.fixed then + previous.value = previous.value .. single + count = count - 1 + else + pattern[count] = { + type = types.fixed, + value = single, + } + end + end + end + + previous_char = c + end + + return pattern + end + + local match + match = function(iterate, pos, pattern, index) + if index == #pattern then + return pos - 1 + end + index = index + 1 + + local current = pattern[index] + local type = current.type + local value = current.value + local iterator = iterate(pos) + + if type == types.capture then + local inner_pattern = value + local inner_to, inner_captures = match(iterate, pos, inner_pattern, 0) + if not inner_to then + return nil + end + + local to, captures = match(iterate, inner_to + 1, pattern, index) + if not to then + return nil + end + + local allcaptures = {{pos, inner_to}, unpack(inner_captures or {})} + local count = #allcaptures + for i = 1, #(captures or {}) do + count = count + 1 + allcaptures[count] = captures[i] + end + + return to, allcaptures + + elseif type == types.fixed then + local compare = value + local compare_length = #compare + local length = 0 + while length < compare_length do + local char = iterator() + if char == nil or char ~= compare:sub(length + 1, length + #char) then + return nil + end + + pos = pos + #char + length = length + #char + end + + return match(iterate, pos, pattern, index) + + elseif type == types.match then + local check = function(c) return c ~= nil and value(bytes(c)) end + local counter = current.counter + if not counter then + local char = iterator() + if not check(char) then + return nil + end + + return match(iterate, pos + #char, pattern, index) + + elseif counter == '-' then + local to, captures = match(iterate, pos, pattern, index) + while to == nil do + local char = iterator() + if char == nil or not check(char) then + return nil + end + + to, captures = match(iterate, pos + #char, pattern, index) + end + + return to, captures + + else + if counter == '?' then + local char = iterator() + if check(char) then + local to, captures = match(iterate, pos + #char, pattern, index) + if to then + return to, captures + end + end + + return match(iterate, pos, pattern, index) + + elseif counter == '*' then + local char = iterator() + + local positions = {} + local count = 0 + while (check(char)) do + local prev = positions[count] or pos + count = count + 1 + positions[count] = prev + #char + char = iterator() + end + + for i = count, 1, -1 do + local to, captures = match(iterate, positions[i], pattern, index) + if to then + return to, captures + end + end + + return match(iterate, pos, pattern, index) + + elseif counter == '+' then + local char = iterator() + + local positions = {} + local count = 0 + while (check(char)) do + local prev = positions[count] or pos + count = count + 1 + positions[count] = prev + #char + char = iterator() + end + + for i = count, 1, -1 do + local to, captures = match(iterate, positions[i], pattern, index) + if to then + return to, captures + end + end + + return nil + + end + end + + elseif type == types.balanced then + local char = iterator() + local open, close = unpack(value) + if char ~= open then + return nil + end + + local count = 1 + repeat + pos = pos + #char + char = iterator() + if char == nil then + return nil + end + + if char == open then + count = count + 1 + elseif char == close then + count = count - 1 + end + until char == close and count == 0 + + return match(iterate, pos + #close, pattern, index) + + elseif type == types.boundary then + local char = iterator() + if char ~= nil then + return nil + end + return match(iterate, pos, pattern, index) + + end + end + + local pack = function(from, iterate, pattern, offset) + local to, captures = match(iterate, from, pattern, offset or 0) + if to == nil then return nil end - local length = fn(str:byte(index, index)) - if length == nil then - error('Invalid code point') + return { + from = from, + to = to, + captures = captures or {}, + } + end + + local findpattern = function(iterate, length, pattern) + local first = pattern[1] + local matches + if first.type == types.boundary and first.value == '^' then + matches = pack(1, iterate, pattern, 1) + else + for i = 1, length do + matches = pack(i, iterate, pattern) + if matches then + break + end + end + end + + return matches + end + + local encoding_mt = { + __index = function(encoding_cache, rawpattern) + local pattern = parse(rawpattern:it(encoding_cache.encoding), classes[encoding_cache.encoding]) + rawset(encoding_cache, rawpattern, pattern) + return pattern + end, + } + + local pattern_cache = setmetatable({}, { + __index = function(pattern_cache, encoding) + local encoding_cache = setmetatable({ encoding = encoding }, encoding_mt) + rawset(pattern_cache, encoding, encoding_cache) + return encoding_cache + end, + }) + + local find = function(plain, str, pattern, encoding, from, to) + if plain then + return findplain(str, pattern, encoding, from, to) + else + local offset = from - 1 + local matches = findpattern(function(pos) return str:it(encoding, pos + offset, to) end, to - offset, pattern_cache[encoding][pattern]) + if not matches then + return nil + end + + for i = 1, #matches.captures do + local first, last = unpack(matches.captures[i]) + matches.captures[i] = str:sub(first + offset, last + offset) + end + + return matches.from + offset, matches.to + offset, unpack(matches.captures) + end + end + + function string.find(str, pattern, encoding, from, to, plain) + if type(encoding) == 'number' then + encoding, from, to, plain = type(to) == 'table' and to or nil, encoding, type(from) == 'number' and from or plain, type(from) == 'boolean' and from or nil + end + encoding = encoding or string.encoding.ascii + + if encoding == string.encoding.ascii and to == nil then + return rawfind(str, pattern, from, plain) end - index = index + length - return str:sub(index - length, index - 1) + return find(plain, str, pattern, encoding, adjust_from(str, from), adjust_to(str, to)) end end - local iterators = { - [string.encoding.ascii] = function(str) - return str:gmatch('.') - end, - [string.encoding.utf8] = function(str) - return process(str, function(byte) - return - byte < 0x80 and 1 or - byte < 0xE0 and 2 or - byte < 0xF0 and 3 or - byte < 0xF8 and 4 - end) - end, - [string.encoding.shift_jis] = function(str) - return process(str, function(byte) - return - (byte < 0x80 or byte >= 0xA1 and byte <= 0xDF) and 1 or - (byte >= 0x80 and byte <= 0x9F or byte >= 0xE0 and byte <= 0xEF or byte >= 0xFA and byte <= 0xFC) and 2 or - byte == 0xFD and 6 - end) - end, - [string.encoding.binary] = function(str) - return str:gmatch('.') - end, - } + do + local rawmatch = string.match + + local process = function(str, first, last, ...) + if not first then + return nil + end + + if select('#', ...) == 0 then + return str:sub(first, last) + end + + return ... + end + + function string.match(str, pattern, encoding, from, to) + if (type(encoding) == 'number') then + encoding, from, to = type(from) == 'table' and from or nil, encoding, type(from) == 'number' and from or to + end + encoding = encoding or string.encoding.ascii + + if encoding == string.encoding.ascii and to == nil then + return rawmatch(str, pattern, from) + end - function string.it(str, encoding) - return iterators[encoding or string.encoding.ascii](str) + return process(str, string.find(str, pattern, adjust_from(str, from), false, encoding, adjust_to(str, to))) + end + end + + do + local rawgmatch = string.gmatch + + function string.gmatch(str, pattern, encoding, from, to) + if (type(encoding) == 'number') then + encoding, from, to = type(from) == 'table' and from or nil, encoding, type(from) == 'number' and from or to + end + encoding = encoding or string.encoding.ascii + + if encoding == string.encoding.ascii and to == nil then + return rawgmatch(str, pattern, from) + end + + local pos = adjust_from(str, from) + local process = function(first, last, ...) + if not first then + return nil + end + + if last >= pos then + pos = last + 1 + else + local char = str:it(encoding, pos)() + pos = pos + #char + end + + if select('#', ...) == 0 then + return str:sub(first, last) + end + + return ... + end + + return function() + return process(string.find(str, pattern, encoding, pos, adjust_to(str, to))) + end + end + end + + do + local rawgsub = string.gsub + + function string.gsub(str, pattern, repl, n, encoding, from, to) + if type(n) == 'table' then + n, encoding, from, to = nil, n, encoding, from + elseif type(encoding) == 'number' then + encoding, from, to = nil, encoding, from + end + encoding = encoding or string.encoding.ascii + + if encoding == string.encoding.ascii and to == nil then + return rawgsub(str, pattern, from) + end + + local repltype = type(repl) + repl = + repltype == 'function' and repl or + repltype == 'table' and function(match) return repl[match] end or + function() return repl end + + local pos = adjust_from(str, from) + local fragments = {} + local count = 0 + repeat + local first, last = string.find(str, pattern, encoding, pos, adjust_to(str, to)) + if first then + count = count + 1 + fragments[count] = str:sub(pos, first - 1) + count = count + 1 + fragments[count] = repl(str:sub(first, last)) + end + pos = last + 1 + until first == nil or count == n + + return table.concat(fragments) .. str:sub(pos) + end end end + -- Removes leading and trailing whitespaces and similar characters (tabs, newlines, etc.). function string.trim(str) return str:match('^%s*(.-)%s*$') @@ -284,7 +892,7 @@ end function string.parse_hex(str) local interpreted_string = str:gsub('0x', ''):gsub('[^%w]', '') if #interpreted_string % 2 ~= 0 then - (_raw and _raw.error or error)('Invalid input string length', 2) + _raw.error('Invalid input string length', 2) end return (interpreted_string:gsub('%w%w', hex_r)) @@ -296,7 +904,7 @@ end function string.parse_binary(str) local interpreted_string = str:gsub('0b', ''):gsub('[^01]', '') if #interpreted_string % 8 ~= 0 then - (_raw and _raw.error or error)('Invalid input string length', 2) + _raw.error('Invalid input string length', 2) end return (interpreted_string:gsub(binary_pattern, binary_r))