|
import sympy |
|
import re |
|
from sympy import matrix_symbols, simplify, factor, expand, apart, expand_trig |
|
from antlr4 import InputStream, CommonTokenStream |
|
from antlr4.error.ErrorListener import ErrorListener |
|
|
|
try: |
|
from gen.PSParser import PSParser |
|
from gen.PSLexer import PSLexer |
|
from gen.PSListener import PSListener |
|
except Exception: |
|
from .gen.PSParser import PSParser |
|
from .gen.PSLexer import PSLexer |
|
from .gen.PSListener import PSListener |
|
|
|
from sympy.printing.str import StrPrinter |
|
|
|
from sympy.parsing.sympy_parser import parse_expr |
|
|
|
import hashlib |
|
|
|
# ---------------------------------------------------------------------------
# Module-level parser state shared by the conversion functions.
# ---------------------------------------------------------------------------

# Passed as ``real=`` to every Symbol the parser creates (None = no assumption).
is_real = None

# Fraction command (\frac, \dfrac or \tfrac) seen in the most recently parsed
# input; ``latex`` re-emits fractions with this command.
frac_type = r'\frac'

# Bindings recorded by assignments / matrix declarations, keyed by SymPy object.
variances = dict()
# The same bindings keyed by the string form of the symbol.
var = dict()

# Values substituted for VARIABLE placeholders during the current parse.
VARIABLE_VALUES = dict()
|
|
|
|
|
def set_real(value):
    """Set the ``real=`` assumption used for every Symbol created afterwards.

    Args:
        value: ``True``/``False`` to force the assumption, or ``None`` for none.
    """
    global is_real
    is_real = value
|
|
|
|
|
def set_variances(vars):
    """Replace the global symbol bindings with *vars*.

    ``variances`` becomes *vars* itself (same object). ``var`` is rebuilt as a
    fresh dict holding the same values keyed by ``str(key)``.

    Args:
        vars: mapping from symbols (or any keys) to their bound values.
    """
    global variances, var
    variances = vars
    var = {}
    var.update((str(key), vars[key]) for key in vars)
|
|
|
|
|
def _normalize_latex(text):
    r"""Rewrite LaTeX spellings the grammar does not accept into ones it does.

    The replacements run in a fixed order:

    * ``\dfrac`` and ``\tfrac`` become ``\frac``.
    * ``\mathrm{T}``, ``\mathrm{d}`` and ``{\rm d}`` become plain letters.
    * ``\left[\begin{matrix}...\end{matrix}\right]`` becomes ``bmatrix``.
    * ``(a)_{b}`` becomes ``\frac{(a)!}{((a)-(b))!}``.
    * ``\displaystyle``, ``\quad``, ``\qquad``, ``~``, ``\,`` and ``$`` become spaces.
    """
    text = text.replace(r'\dfrac', r'\frac').replace(r'\tfrac', r'\frac')
    text = text.replace(r'\mathrm{T}', 'T')
    text = text.replace(r'\mathrm{d}', 'd').replace(r'{\rm d}', 'd')
    text = (text.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}')
                .replace(r'\end{matrix}\right]', r'\end{bmatrix}'))
    text = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}",
                  r"\\frac{(\1)!}{((\1)-(\2))!}", text)
    for layout_token in (r'\displaystyle', r'\quad', r'\qquad', r'~', r'\,', r'$'):
        text = text.replace(layout_token, ' ')
    return text


def latex2sympy(sympy: str, variable_values=None):
    """Parse a LaTeX math string into a SymPy object.

    Side effects: updates the module-level ``frac_type`` (when the input
    contains a fraction command) and ``VARIABLE_VALUES``.

    Args:
        sympy: the LaTeX source text. The name shadows the ``sympy`` module
            inside this function; it is kept for backward compatibility with
            keyword callers.
        variable_values: optional mapping used to resolve VARIABLE
            placeholders. ``None`` or an empty mapping means "no values".

    Returns:
        The converted relation/expression, or a list of them when the input
        is a relation list.

    Raises:
        Exception: on a syntax error, raised by ``MathErrorListener``.
    """
    global frac_type
    global VARIABLE_VALUES

    # Remember which fraction command the caller used so ``latex`` can
    # re-emit it. When several appear, the last check below wins.
    if sympy.find(r'\frac') != -1:
        frac_type = r'\frac'
    if sympy.find(r'\dfrac') != -1:
        frac_type = r'\dfrac'
    if sympy.find(r'\tfrac') != -1:
        frac_type = r'\tfrac'

    sympy = _normalize_latex(sympy)

    # ``None`` (the default) and ``{}`` both mean "no variable values". A None
    # default avoids sharing one mutable dict across calls.
    VARIABLE_VALUES = variable_values if variable_values else {}

    # The same listener turns lexer and parser errors into exceptions.
    error_listener = MathErrorListener(sympy)

    lexer = PSLexer(InputStream(sympy))
    lexer.removeErrorListeners()
    lexer.addErrorListener(error_listener)

    parser = PSParser(CommonTokenStream(lexer))
    parser.removeErrorListeners()
    parser.addErrorListener(error_listener)

    math = parser.math()

    if math.relation_list():
        content = math.relation_list().relation_list_content()
        return [convert_relation(item) for item in content.relation()]

    return convert_relation(math.relation())
|
|
|
|
|
class MathErrorListener(ErrorListener):
    """ANTLR error listener that raises a readable ``Exception`` on any syntax error.

    The message has three lines: a description, the source text, and a
    ``~~~^`` marker pointing at the offending column.
    """

    def __init__(self, src):
        # Initialise through the real base class. The previous
        # ``super(ErrorListener, self)`` skipped ErrorListener's own __init__.
        super().__init__()
        self.src = src

    def syntaxError(self, recog, symbol, line, col, msg, e):
        """Raise an ``Exception`` describing the error at *col* of ``self.src``."""
        fmt = "%s\n%s\n%s"
        marker = "~" * col + "^"

        if msg.startswith("missing"):
            description = msg
        elif msg.startswith("no viable"):
            description = "I expected something else here"
        elif msg.startswith("mismatched"):
            description = "I expected something else here"
            # The exception (or its expected-token set) can be unavailable.
            # In that case fall back to the generic message instead of crashing.
            expected_tokens = e.getExpectedTokens() if e is not None else None
            if expected_tokens is not None:
                names = PSParser.literalNames
                expected = [names[i] for i in expected_tokens if i < len(names)]
                # Only list the alternatives when there are few enough to read.
                if len(expected) < 10:
                    description = "I expected one of these: " + " ".join(expected)
        else:
            description = "I don't understand this"

        raise Exception(fmt % (description, self.src, marker))
|
|
|
|
|
def convert_relation(rel):
    r"""Convert a ``relation`` parse node.

    * A bare expression is converted with ``convert_expr``.
    * ``<``, ``\le``, ``>``, ``\ge``, ``=`` and ``\ne`` give unevaluated SymPy
      relationals.
    * ASSIGNMENT with a symbol on the left records the binding in
      ``variances``/``var`` and returns the right-hand side. Otherwise it solves
      ``lh - rh = 0`` for each free symbol and returns a list of ``Eq``, or a
      plain ``Eq`` when there are no symbols.
    * IN with a right-hand side of the form ``X^{n \times m}`` declares the
      left-hand side as an ``n x m`` ``MatrixSymbol`` and returns the left-hand
      side.

    Returns ``None`` when the node matches none of the operators above.
    """
    if rel.expr():
        return convert_expr(rel.expr())

    lh = convert_relation(rel.relation(0))
    rh = convert_relation(rel.relation(1))

    if rel.LT():
        return sympy.StrictLessThan(lh, rh, evaluate=False)
    elif rel.LTE():
        return sympy.LessThan(lh, rh, evaluate=False)
    elif rel.GT():
        return sympy.StrictGreaterThan(lh, rh, evaluate=False)
    elif rel.GTE():
        return sympy.GreaterThan(lh, rh, evaluate=False)
    elif rel.EQUAL():
        return sympy.Eq(lh, rh, evaluate=False)
    elif rel.ASSIGNMENT():
        if lh.is_Symbol:
            # Record the binding under both the symbol and its string name.
            variances[lh] = rh
            var[str(lh)] = rh
            return rh
        # Not a plain symbol: solve the equation for each free symbol instead.
        equation = lh - rh
        syms = equation.atoms(sympy.Symbol)
        if len(syms) > 0:
            result = []
            for sym in syms:
                for value in sympy.solve(equation, sym):
                    result.append(sympy.Eq(sym, value, evaluate=False))
            return result
        return sympy.Eq(lh, rh, evaluate=False)
    elif rel.IN():
        # Matrix declaration such as ``A \in \mathbb{R}^{n \times m}``. The
        # exponent must be a product. Previously ``hasattr(rh.exp, 'is_Mul')``
        # was always true, so ``R^n`` crashed with IndexError instead of
        # reaching the error below.
        if getattr(rh, 'is_Pow', False) and getattr(rh.exp, 'is_Mul', False):
            n = rh.exp.args[0]
            m = rh.exp.args[1]
            # Dimensions may themselves be previously bound symbols.
            if n in variances:
                n = variances[n]
            if m in variances:
                m = variances[m]
            rh = sympy.MatrixSymbol(lh, n, m)
            variances[lh] = rh
            var[str(lh)] = rh
        else:
            raise Exception("Don't support this form of definition of matrix symbol.")
        return lh
    elif rel.UNEQUAL():
        return sympy.Ne(lh, rh, evaluate=False)
|
|
|
|
|
def convert_expr(expr):
    """Convert an ``expr`` parse node via its ``additive`` child.

    Returns ``None`` when the node has no ``additive`` child.
    """
    additive = expr.additive()
    return convert_add(additive) if additive else None
|
|
|
|
|
_ROW_COL_MISMATCH = "Row and col don't match"


def _transform_scale_factor(transform_scale):
    """Return the multiplier of a ``transform_scale`` node.

    An explicit expression or group is converted. A bare minus sign gives -1,
    and no prefix gives 1.
    """
    if transform_scale.expr():
        return convert_expr(transform_scale.expr())
    if transform_scale.group():
        return convert_expr(transform_scale.group().expr())
    if transform_scale.SUB():
        return -1
    return 1


def _transform_atom_parts(transform_atom):
    """Return ``(letter, zero_based_index)`` for a ``transform_atom`` such as ``r_2``."""
    letter = transform_atom.LETTER_NO_E().getText()
    index = int(transform_atom.NUMBER().getText()) - 1
    return letter, index


def convert_elementary_transform(matrix, transform):
    """Apply one elementary row/column operation to *matrix* and return the result.

    Letter ``r`` selects a row operation and ``c`` a column operation. Indices
    in the source are 1-based.

    * scale: ``k r_i`` -> ``elementary_*_op('n->kn')``
    * swap: ``r_i <-> r_j`` -> ``elementary_*_op('n<->m')``
    * assignment: ``r_i + k r_j`` -> ``elementary_*_op('n->n+km')``

    Returns *matrix* unchanged when *transform* is none of these.

    Raises:
        Exception: if the letters are not ``r``/``c`` or do not agree.
    """
    if transform.transform_scale():
        scale = transform.transform_scale()
        letter, index = _transform_atom_parts(scale.transform_atom())
        k = _transform_scale_factor(scale)
        if letter == 'r':
            return matrix.elementary_row_op(op='n->kn', row=index, k=k)
        if letter == 'c':
            return matrix.elementary_col_op(op='n->kn', col=index, k=k)
        raise Exception(_ROW_COL_MISMATCH)

    if transform.transform_swap():
        atoms = transform.transform_swap().transform_atom()
        first_letter, first_index = _transform_atom_parts(atoms[0])
        second_letter, second_index = _transform_atom_parts(atoms[1])
        if first_letter != second_letter:
            raise Exception(_ROW_COL_MISMATCH)
        if first_letter == 'r':
            return matrix.elementary_row_op(op='n<->m', row1=first_index, row2=second_index)
        if first_letter == 'c':
            return matrix.elementary_col_op(op='n<->m', col1=first_index, col2=second_index)
        raise Exception(_ROW_COL_MISMATCH)

    if transform.transform_assignment():
        assignment = transform.transform_assignment()
        scale = assignment.transform_scale()
        k = _transform_scale_factor(scale)
        # target is the row/column being replaced; source is the scaled one added to it
        target_letter, target_index = _transform_atom_parts(assignment.transform_atom())
        source_letter, source_index = _transform_atom_parts(scale.transform_atom())
        if target_letter != source_letter:
            raise Exception(_ROW_COL_MISMATCH)
        if target_letter == 'r':
            return matrix.elementary_row_op(op='n->n+km', k=k, row1=target_index, row2=source_index)
        if target_letter == 'c':
            return matrix.elementary_col_op(op='n->n+km', k=k, col1=target_index, col2=source_index)
        raise Exception(_ROW_COL_MISMATCH)

    return matrix
|
|
|
|
|
def convert_matrix(matrix):
    """Build a ``sympy.Matrix`` from a matrix parse node.

    Entries are converted row by row. If the node carries an ``\\xrightarrow``
    with elementary transforms, they are applied in order. With two transform
    groups, group 1 is applied before group 0.
    """
    entries = [[convert_expr(cell) for cell in row.expr()] for row in matrix.matrix_row()]
    result = sympy.Matrix(entries)

    has_arrow = hasattr(matrix, 'MATRIX_XRIGHTARROW') and matrix.MATRIX_XRIGHTARROW()
    if not has_arrow:
        return result

    groups = matrix.elementary_transforms()
    if len(groups) == 1:
        ordered_groups = [groups[0]]
    elif len(groups) == 2:
        ordered_groups = [groups[1], groups[0]]
    else:
        ordered_groups = []

    for group in ordered_groups:
        for step in group.elementary_transform():
            result = convert_elementary_transform(result, step)
    return result
|
|
|
|
|
def add_flat(lh, rh):
    """Return an unevaluated ``Add`` of *lh* and *rh*.

    Any operand that is already an ``Add`` contributes its terms directly, so
    the result is a single flat sum.
    """
    def _is_add(operand):
        return hasattr(operand, 'is_Add') and operand.is_Add

    left_terms = list(lh.args) if _is_add(lh) else [lh]
    right_terms = list(rh.args) if _is_add(rh) else [rh]
    return sympy.Add(*(left_terms + right_terms), evaluate=False)
|
|
|
|
|
def mat_add_flat(lh, rh):
    """Return an unevaluated ``MatAdd`` of *lh* and *rh*.

    Operands that are already ``MatAdd`` contribute their terms directly.
    Every term is ``doit()``-ed before being added.
    """
    def _is_mat_add(operand):
        return hasattr(operand, 'is_MatAdd') and operand.is_MatAdd

    left_terms = list(lh.args) if _is_mat_add(lh) else [lh]
    right_terms = list(rh.args) if _is_mat_add(rh) else [rh]
    evaluated = [term.doit() for term in left_terms + right_terms]
    return sympy.MatAdd(*evaluated, evaluate=False)
|
|
|
|
|
def mul_flat(lh, rh):
    """Return an unevaluated ``Mul`` of *lh* and *rh*.

    Any operand that is already a ``Mul`` contributes its factors directly,
    so the result is a single flat product.
    """
    def _is_mul(operand):
        return hasattr(operand, 'is_Mul') and operand.is_Mul

    left_factors = list(lh.args) if _is_mul(lh) else [lh]
    right_factors = list(rh.args) if _is_mul(rh) else [rh]
    return sympy.Mul(*(left_factors + right_factors), evaluate=False)
|
|
|
|
|
def mat_mul_flat(lh, rh):
    """Return an unevaluated ``MatMul`` of *lh* and *rh*.

    Operands that are already ``MatMul`` contribute their factors directly.
    Each factor is ``doit()``-ed when it has a ``doit`` method. Plain Python
    numbers (e.g. a literal ``-1`` used for negation) pass through unchanged.
    Previously the flattening path called ``.doit()`` unconditionally, so
    ``mat_mul_flat(-1, A*B)`` raised AttributeError.
    """
    def _is_mat_mul(operand):
        return hasattr(operand, 'is_MatMul') and operand.is_MatMul

    def _doit_if_possible(operand):
        return operand.doit() if hasattr(operand, 'doit') else operand

    left_factors = list(lh.args) if _is_mat_mul(lh) else [lh]
    right_factors = list(rh.args) if _is_mat_mul(rh) else [rh]
    factors = [_doit_if_possible(factor) for factor in left_factors + right_factors]
    return sympy.MatMul(*factors, evaluate=False)
|
|
|
|
|
def convert_add(add):
    """Convert an ``additive`` parse node: ``a + b``, ``a - b``, or a single ``mp`` term.

    Matrix operands use ``mat_add_flat``. Scalar operands use ``add_flat``.
    Subtraction is expressed as adding the negated right operand.
    """
    is_plus = add.ADD()
    if not is_plus and not add.SUB():
        return convert_mp(add.mp())

    lh = convert_add(add.additive(0))
    rh = convert_add(add.additive(1))
    involves_matrix = lh.is_Matrix or rh.is_Matrix

    if is_plus:
        return mat_add_flat(lh, rh) if involves_matrix else add_flat(lh, rh)

    if involves_matrix:
        return mat_add_flat(lh, mat_mul_flat(-1, rh))

    # Plain numbers are negated directly; anything else becomes (-1)*rh.
    negated = -rh if rh.func.is_Number else mul_flat(-1, rh)
    return add_flat(lh, negated)
|
|
|
|
|
def convert_mp(mp):
    r"""Convert an ``mp`` / ``mp_nofunc`` parse node.

    Handles products (``*``, ``\times``, ``\cdot``), quotients (``/``,
    ``\div``, ``:``) as ``lh * rh**-1``, ``\mod``, and otherwise delegates to
    the unary child.
    """
    if hasattr(mp, 'mp'):
        left_node, right_node = mp.mp(0), mp.mp(1)
    else:
        left_node, right_node = mp.mp_nofunc(0), mp.mp_nofunc(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        lh = convert_mp(left_node)
        rh = convert_mp(right_node)
        if lh.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(lh, rh)
        return mul_flat(lh, rh)

    if mp.DIV() or mp.CMD_DIV() or mp.COLON():
        lh = convert_mp(left_node)
        rh = convert_mp(right_node)
        product_cls = sympy.MatMul if (lh.is_Matrix or rh.is_Matrix) else sympy.Mul
        return product_cls(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)

    if mp.CMD_MOD():
        lh = convert_mp(left_node)
        rh = convert_mp(right_node)
        if rh.is_Matrix:
            raise Exception("Cannot perform modulo operation with a matrix as an operand")
        return sympy.Mod(lh, rh, evaluate=False)

    unary_node = mp.unary() if hasattr(mp, 'unary') else mp.unary_nofunc()
    return convert_unary(unary_node)
|
|
|
|
|
def convert_unary(unary):
    """Convert a ``unary`` / ``unary_nofunc`` parse node.

    A leading ``+`` is dropped. A leading ``-`` negates the operand: matrices
    via ``mat_mul_flat(-1, ...)``, numbers directly, anything else as
    ``(-1) * operand``. Without a sign, the postfix chain is converted
    (implicit multiplication). Returns ``None`` if there is nothing to convert.
    """
    if hasattr(unary, 'unary'):
        nested_unary = unary.unary()
    else:
        nested_unary = unary.unary_nofunc()

    if hasattr(unary, 'postfix_nofunc'):
        # unary_nofunc: one leading postfix followed by zero or more postfix_nofunc.
        postfix = [unary.postfix()] + unary.postfix_nofunc()
    else:
        postfix = unary.postfix()

    if unary.ADD():
        return convert_unary(nested_unary)

    if unary.SUB():
        operand = convert_unary(nested_unary)
        if operand.is_Matrix:
            # mat_mul_flat has no ``evaluate`` parameter. Passing one raised TypeError.
            return mat_mul_flat(-1, operand)
        if operand.func.is_Number:
            return -operand
        return mul_flat(-1, operand)

    if postfix:
        return convert_postfix_list(postfix)
    return None
|
|
|
|
|
def convert_postfix_list(arr, i=0):
    r"""Convert ``arr[i:]``, a run of juxtaposed postfix nodes (implicit multiplication).

    * A derivative marker ``[wrt]`` (as returned by ``convert_frac`` for
      ``\frac{d}{dx}``) followed by more nodes becomes
      ``Derivative(rest, wrt)``.
    * Expressions, matrices and the empty set are multiplied with the rest of
      the run (``MatMul`` when a matrix is involved).
    * Other tuples/lists/dicts (e.g. results of ``\operatorname{nullspace}``)
      are returned as-is.

    Raises:
        Exception: if *i* is past the end of *arr*, or a derivative has no operand.
    """
    if i >= len(arr):
        raise Exception("Index out of bounds")

    res = convert_postfix(arr[i])
    is_last = i == len(arr) - 1

    # Check the derivative marker first. Previously the generic list branch
    # below swallowed it, so ``\frac{d}{dx} f`` returned ``[x]``.
    is_derivative_marker = (isinstance(res, list) and len(res) == 1
                            and isinstance(res[0], sympy.Symbol))
    if is_derivative_marker and not is_last:
        return sympy.Derivative(convert_postfix_list(arr, i + 1), res[0])

    if isinstance(res, (sympy.Expr, sympy.Matrix)) or res is sympy.S.EmptySet:
        if is_last:
            return res
        rh = convert_postfix_list(arr, i + 1)
        if res.is_Matrix or rh.is_Matrix:
            return mat_mul_flat(res, rh)
        return mul_flat(res, rh)

    if isinstance(res, (tuple, list, dict)):
        return res

    # Fallback: treat any other indexable result as a derivative marker.
    wrt = res[0]
    if is_last:
        raise Exception("Expected expression for derivative")
    return sympy.Derivative(convert_postfix_list(arr, i + 1), wrt)
|
|
|
|
|
def do_subs(expr, at):
    """Apply the substitution described by an evaluation-bar bound *at* to *expr*.

    * ``at`` holding an expression: if it contains symbols, one of them (the
      first yielded by the symbol set) is replaced in *expr* by that whole
      expression. If it contains no symbols, *expr* is returned unchanged.
    * ``at`` holding an equality ``lhs = rhs``: ``lhs`` is replaced by ``rhs``.

    Returns ``None`` when *at* holds neither.
    """
    if at.expr():
        replacement = convert_expr(at.expr())
        free_symbols = replacement.atoms(sympy.Symbol)
        if not free_symbols:
            return expr
        target = next(iter(free_symbols))
        return expr.subs(target, replacement)

    if at.equality():
        equality = at.equality()
        old = convert_expr(equality.expr(0))
        new = convert_expr(equality.expr(1))
        return expr.subs(old, new)

    return None
|
|
|
|
|
def convert_postfix(postfix):
    """Convert a ``postfix`` / ``postfix_nofunc`` node: a base followed by postfix operators.

    Supported operators:

    * ``!`` -- unevaluated ``factorial`` (rejected on a derivative marker).
    * evaluation bar -- ``upper - lower`` substitutions via ``do_subs``; with a
      single bound, just that substitution.
    * transpose -- ``.T``, falling back to ``sympy.transpose``. Values that
      support neither are left unchanged.

    Raises:
        Exception: when ``!`` is applied to a derivative marker.
    """
    if hasattr(postfix, 'exp'):
        exp_nested = postfix.exp()
    else:
        exp_nested = postfix.exp_nofunc()

    exp = convert_exp(exp_nested)
    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise Exception("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
        elif op.eval_at():
            ev = op.eval_at()
            at_b = None
            at_a = None
            if ev.eval_at_sup():
                at_b = do_subs(exp, ev.eval_at_sup())
            if ev.eval_at_sub():
                at_a = do_subs(exp, ev.eval_at_sub())
            if at_b is not None and at_a is not None:
                exp = add_flat(at_b, mul_flat(at_a, -1))
            elif at_b is not None:
                exp = at_b
            elif at_a is not None:
                exp = at_a
        elif op.transpose():
            # Only ordinary errors are treated as "not transposable". The old
            # bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
            try:
                exp = exp.T
            except Exception:
                try:
                    exp = sympy.transpose(exp)
                except Exception:
                    pass  # best effort: leave the value unchanged

    return exp
|
|
|
|
|
def convert_exp(exp):
    """Convert an ``exp`` / ``exp_nofunc`` node.

    ``base^exponent`` becomes an unevaluated ``Pow``. The exponent comes from
    the node's ``atom`` or ``expr`` child. Without a nested base, the node's
    ``comp`` child is converted instead.

    Raises:
        Exception: when the base is a derivative marker.
    """
    exp_nested = exp.exp() if hasattr(exp, 'exp') else exp.exp_nofunc()

    if not exp_nested:
        comp_node = exp.comp() if hasattr(exp, 'comp') else exp.comp_nofunc()
        return convert_comp(comp_node)

    base = convert_exp(exp_nested)
    if isinstance(base, list):
        raise Exception("Cannot raise derivative to power")

    if exp.atom():
        exponent = convert_atom(exp.atom())
    elif exp.expr():
        exponent = convert_expr(exp.expr())
    return sympy.Pow(base, exponent, evaluate=False)
|
|
|
|
|
def convert_comp(comp):
    """Convert a ``comp`` / ``comp_nofunc`` node by dispatching on its first present child.

    Children are tried in order: group, norm, abs, floor, ceil, atom, frac,
    binom, matrix, det, func. Accessors are looked up lazily, one at a time,
    because ``comp_nofunc`` nodes may lack later ones. Returns ``None`` when
    no child is present.
    """
    dispatch = (
        ('group', lambda node: convert_expr(node.expr())),
        ('norm_group', lambda node: convert_expr(node.expr()).norm()),
        ('abs_group', lambda node: sympy.Abs(convert_expr(node.expr()), evaluate=False)),
        ('floor_group', lambda node: handle_floor(convert_expr(node.expr()))),
        ('ceil_group', lambda node: handle_ceil(convert_expr(node.expr()))),
        ('atom', convert_atom),
        ('frac', convert_frac),
        ('binom', convert_binom),
        ('matrix', convert_matrix),
        # Substitute known bindings before taking the determinant.
        ('det', lambda node: convert_matrix(node).subs(variances).det()),
        ('func', convert_func),
    )
    for accessor_name, converter in dispatch:
        node = getattr(comp, accessor_name)()
        if node:
            return converter(node)
    return None
|
|
|
|
|
def convert_atom(atom):
    r"""Convert an ``atom`` parse node into a SymPy object.

    Handled forms:

    * ``atom_expr`` -- a letter, Greek command, other symbol command or accent.
      An optional subscript is folded into the symbol name. An optional
      superscript is returned as an unevaluated ``Pow``. The letter ``I`` is
      the imaginary unit. A name bound to a shaped value in ``var`` is returned
      as a ``MatrixSymbol`` of that shape.
    * ``SYMBOL`` -- ``\infty``, ``\pi`` or ``\emptyset``.
    * ``NUMBER`` / ``E_NOTATION`` / ``PERCENT_NUMBER`` -- commas removed; exact
      ``Rational`` when possible, otherwise ``sympy.Number``.
    * ``DIFFERENTIAL`` -- a symbol named ``d`` + variable name.
    * ``mathit`` -- a symbol named after the raw ``\mathit`` text.
    * ``VARIABLE`` -- a placeholder (prefix of 10 characters and closing
      character stripped). Resolved from ``VARIABLE_VALUES`` when present,
      otherwise a symbol named ``name`` + md5(name). A trailing ``\%``
      divides by 100.

    Raises:
        Exception: for an unrecognised ``SYMBOL``.
    """
    if atom.atom_expr():
        atom_expr = atom.atom_expr()

        # Base name of the symbol.
        atom_text = ''
        if atom_expr.LETTER_NO_E():
            atom_text = atom_expr.LETTER_NO_E().getText()
            if atom_text == "I":
                return sympy.I
        elif atom_expr.GREEK_CMD():
            atom_text = atom_expr.GREEK_CMD().getText()[1:].strip()
        elif atom_expr.OTHER_SYMBOL_CMD():
            atom_text = atom_expr.OTHER_SYMBOL_CMD().getText().strip()
        elif atom_expr.accent():
            atom_accent = atom_expr.accent()
            # accent command text + "{" + raw base text + "}"
            name = atom_accent.start.text
            base = atom_accent.base.getText()
            atom_text = name + '{' + base + '}'

        # The subscript is kept as raw source text and becomes part of the name.
        subscript_text = ''
        if atom_expr.subexpr():
            subexpr = atom_expr.subexpr()
            subscript = None
            if subexpr.expr():
                subscript = subexpr.expr().getText().strip()
            elif subexpr.atom():
                subscript = subexpr.atom().getText().strip()
            elif subexpr.args():
                subscript = subexpr.args().getText().strip()
            subscript_inner_text = StrPrinter().doprint(subscript)
            if len(subscript_inner_text) > 1:
                subscript_text = '_{' + subscript_inner_text + '}'
            else:
                subscript_text = '_' + subscript_inner_text

        symbol_name = atom_text + subscript_text
        atom_symbol = sympy.Symbol(symbol_name, real=is_real)

        # If the name is bound to something with a shape, represent it as a
        # MatrixSymbol of that shape. ``var`` is the module-level dict; it is
        # only read here, so no ``global`` declaration is needed.
        matrix_symbol = None
        if symbol_name in var:
            try:
                rh = var[symbol_name]
                shape = sympy.shape(rh)
                matrix_symbol = sympy.MatrixSymbol(symbol_name, shape[0], shape[1])
                variances[matrix_symbol] = variances[atom_symbol]
            except Exception:
                # Best effort: a binding without a shape, or with no matching
                # ``variances`` entry, is ignored here.
                pass

        if atom_expr.supexpr():
            supexpr = atom_expr.supexpr()
            if supexpr.expr():
                func_pow = convert_expr(supexpr.expr())
            else:
                func_pow = convert_atom(supexpr.atom())
            return sympy.Pow(atom_symbol, func_pow, evaluate=False)

        return matrix_symbol if matrix_symbol is not None else atom_symbol

    elif atom.SYMBOL():
        s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "")
        if s == "\\infty":
            return sympy.oo
        elif s == '\\pi':
            return sympy.pi
        elif s == '\\emptyset':
            return sympy.S.EmptySet
        else:
            raise Exception("Unrecognized symbol")

    elif atom.NUMBER():
        s = atom.NUMBER().getText().replace(",", "")
        try:
            return sympy.Rational(s)
        except (TypeError, ValueError):
            return sympy.Number(s)

    elif atom.E_NOTATION():
        s = atom.E_NOTATION().getText().replace(",", "")
        try:
            return sympy.Rational(s)
        except (TypeError, ValueError):
            return sympy.Number(s)

    elif atom.DIFFERENTIAL():
        # Local name. Previously this assigned to ``var``, which this function
        # declared ``global``, so parsing a differential replaced the module's
        # ``var`` bindings dict with a Symbol.
        diff_var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol('d' + diff_var.name, real=is_real)

    elif atom.mathit():
        text = rule2text(atom.mathit().mathit_text())
        return sympy.Symbol(text, real=is_real)

    elif atom.VARIABLE():
        text = atom.VARIABLE().getText()
        is_percent = text.endswith("\\%")
        trim_amount = 3 if is_percent else 1
        name = text[10:]
        name = name[0:len(name) - trim_amount]

        # Unresolved placeholders get an md5 suffix of their name, presumably
        # to keep them distinct from ordinary user symbols.
        name_hash = hashlib.md5(name.encode()).hexdigest()
        symbol_name = name + name_hash

        if name in VARIABLE_VALUES:
            value = VARIABLE_VALUES[name]
            # SymPy objects are used directly. ``sympy.Basic`` is the public
            # base class, used instead of the internal ``core.all_classes`` registry.
            if isinstance(value, sympy.Basic):
                symbol = value
            else:
                # NOTE(security): parse_expr evaluates Python code, so values
                # must come from a trusted caller.
                symbol = parse_expr(str(value))
        else:
            symbol = sympy.Symbol(symbol_name, real=is_real)

        if is_percent:
            return sympy.Mul(symbol, sympy.Pow(100, -1, evaluate=False), evaluate=False)
        return symbol

    elif atom.PERCENT_NUMBER():
        text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace(",", "")
        try:
            number = sympy.Rational(text)
        except (TypeError, ValueError):
            number = sympy.Number(text)
        return sympy.Rational(number, 100)
|
|
|
|
|
def rule2text(ctx):
    """Return the exact source text spanned by parser rule context *ctx*.

    Unlike ``ctx.getText()``, this reads the original input stream between the
    first token's start offset and the last token's stop offset, so whitespace
    between tokens is preserved.
    """
    first_token, last_token = ctx.start, ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)
|
|
|
|
|
def convert_frac(frac):
    r"""Convert a ``\frac{upper}{lower}`` node.

    Derivative notation is recognised first:

    * ``\frac{d}{dx}`` / ``\frac{\partial}{\partial x}`` return the marker
      ``[wrt]`` (a one-element list) to be applied to what follows.
    * ``\frac{d f}{dx}`` / ``\frac{\partial f}{\partial x}`` return
      ``Derivative(f, wrt)``; ``f`` is re-parsed with ``latex2sympy``.

    Any other fraction becomes an unevaluated ``upper * lower**-1``
    (``MatMul`` when a matrix is involved).
    """
    diff_op = False
    partial_op = False
    wrt = None
    lower_itv = frac.lower.getSourceInterval()
    lower_itv_len = lower_itv[1] - lower_itv[0] + 1

    if (frac.lower.start == frac.lower.stop
            and frac.lower.start.type == PSLexer.DIFFERENTIAL):
        wrt = get_differential_var_str(frac.lower.start.text)
        diff_op = True
    elif (lower_itv_len == 2
          and frac.lower.start.type == PSLexer.SYMBOL
          and frac.lower.start.text == '\\partial'
          and frac.lower.stop.type in (PSLexer.LETTER_NO_E, PSLexer.SYMBOL)):
        partial_op = True
        wrt = frac.lower.stop.text
        if frac.lower.stop.type == PSLexer.SYMBOL:
            wrt = wrt[1:]  # drop the backslash of a command such as \theta

    if diff_op or partial_op:
        wrt = sympy.Symbol(wrt, real=is_real)
        if (diff_op and frac.upper.start == frac.upper.stop
                and frac.upper.start.type == PSLexer.LETTER_NO_E
                and frac.upper.start.text == 'd'):
            return [wrt]
        elif (partial_op and frac.upper.start == frac.upper.stop
              and frac.upper.start.type == PSLexer.SYMBOL
              and frac.upper.start.text == '\\partial'):
            return [wrt]

        upper_text = rule2text(frac.upper)
        expr_top = None
        # Pass the current VARIABLE_VALUES along. Calling latex2sympy without
        # them reset the module-level values to {} for the rest of the parse.
        if diff_op and upper_text.startswith('d'):
            expr_top = latex2sympy(upper_text[1:], VARIABLE_VALUES)
        elif partial_op and frac.upper.start.text == '\\partial':
            expr_top = latex2sympy(upper_text[len('\\partial'):], VARIABLE_VALUES)
        # Use an explicit None test: an expression equal to 0 is falsy, and
        # the truth value of a relational is undefined.
        if expr_top is not None:
            return sympy.Derivative(expr_top, wrt)

    expr_top = convert_expr(frac.upper)
    expr_bot = convert_expr(frac.lower)
    if expr_top.is_Matrix or expr_bot.is_Matrix:
        return sympy.MatMul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False)
    return sympy.Mul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False)
|
|
|
|
|
def convert_binom(binom):
    r"""Convert ``\binom{n}{k}`` into ``sympy.binomial(n, k)``."""
    n = convert_expr(binom.upper)
    k = convert_expr(binom.lower)
    return sympy.binomial(n, k)
|
|
|
|
|
def _convert_func_pow(func):
    """Return the converted ``^{...}`` exponent attached to *func*, or ``None`` when absent."""
    supexpr = func.supexpr()
    if not supexpr:
        return None
    if supexpr.expr():
        return convert_expr(supexpr.expr())
    return convert_atom(supexpr.atom())


def _convert_single_arg_operatorname(operatorname, arg):
    r"""Apply a single-argument ``\operatorname{...}`` to the converted *arg*.

    Raises:
        Exception: for an operator name that is not supported.
    """
    if operatorname in ["arsinh", "arcosh", "artanh"]:
        return getattr(sympy.functions, "a" + operatorname[2:])(arg, evaluate=False)
    if operatorname in ["arcsinh", "arccosh", "arctanh"]:
        return getattr(sympy.functions, "a" + operatorname[3:])(arg, evaluate=False)
    if operatorname == "floor":
        return handle_floor(arg)
    if operatorname == "ceil":
        return handle_ceil(arg)
    if operatorname == 'eye':
        return sympy.eye(arg)
    if operatorname == 'rank':
        return sympy.Integer(arg.rank())
    if operatorname in ['trace', 'tr']:
        return arg.trace()
    if operatorname == 'rref':
        return arg.rref()[0]
    if operatorname == 'nullspace':
        return arg.nullspace()
    if operatorname == 'norm':
        return arg.norm()
    if operatorname == 'cols':
        return [arg.col(i) for i in range(arg.cols)]
    if operatorname == 'rows':
        return [arg.row(i) for i in range(arg.rows)]
    if operatorname in ['eig', 'eigen', 'diagonalize']:
        return arg.diagonalize()
    if operatorname in ['eigenvals', 'eigenvalues']:
        return arg.eigenvals()
    if operatorname in ['eigenvects', 'eigenvectors']:
        return arg.eigenvects()
    if operatorname in ['svd', 'SVD']:
        return arg.singular_value_decomposition()
    raise Exception("Unsupported operator name: " + operatorname)


def _convert_single_arg_func(func):
    r"""Convert a built-in single-argument function such as ``\sin x`` or ``\log_2(x)``.

    ``\sin^{-1}`` (and the other trig/hyperbolic names) maps to the inverse
    function instead of a power.
    """
    if func.L_PAREN():
        arg = convert_func_arg(func.func_single_arg())
    else:
        arg = convert_func_arg(func.func_single_arg_noparens())

    name = func.func_normal_single_arg().start.text[1:]
    expr = None

    # \arcsin -> asin, \arsinh -> asinh, \arcsinh -> asinh, ...
    if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"]:
        expr = getattr(sympy.functions, "a" + name[3:])(arg, evaluate=False)
    elif name in ["arsinh", "arcosh", "artanh"]:
        expr = getattr(sympy.functions, "a" + name[2:])(arg, evaluate=False)
    elif name in ["arcsinh", "arccosh", "arctanh"]:
        expr = getattr(sympy.functions, "a" + name[3:])(arg, evaluate=False)
    elif name == "operatorname":
        operatorname = func.func_normal_single_arg().func_operator_name.getText()
        expr = _convert_single_arg_operatorname(operatorname, arg)
    elif name in ["log", "ln"]:
        subexpr = func.subexpr()
        if subexpr:
            if subexpr.atom():
                base = convert_atom(subexpr.atom())
            else:
                base = convert_expr(subexpr.expr())
        elif name == "log":
            base = 10
        else:
            base = sympy.E
        expr = sympy.log(arg, base, evaluate=False)
    elif name in ["exp", "exponentialE"]:
        expr = sympy.exp(arg)
    elif name == "floor":
        expr = handle_floor(arg)
    elif name == "ceil":
        expr = handle_ceil(arg)
    elif name == 'det':
        expr = arg.det()

    func_pow = _convert_func_pow(func)
    should_pow = True
    if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]:
        if func_pow == -1:
            name = "a" + name
            should_pow = False
        expr = getattr(sympy.functions, name)(arg, evaluate=False)

    if expr is None:
        raise Exception("Unsupported function: " + name)
    # ``is not None`` so that an exponent of 0 is kept rather than dropped.
    if func_pow is not None and should_pow:
        expr = sympy.Pow(expr, func_pow, evaluate=False)
    return expr


def _convert_multi_arg_func(func):
    r"""Convert a built-in multi-argument function (gcd/lcm/max/min or ``\operatorname``)."""
    if func.L_PAREN():
        arg_text = func.func_multi_arg().getText()
    else:
        # The node must be turned into text before splitting. Calling
        # ``.split`` on the parse node itself raised AttributeError.
        arg_text = func.func_multi_arg_noparens().getText()
    args = [latex2sympy(text, VARIABLE_VALUES) for text in arg_text.split(",")]

    name = func.func_normal_multi_arg().start.text[1:]
    expr = None

    if name == "operatorname":
        operatorname = func.func_normal_multi_arg().func_operator_name.getText()
        if operatorname in ["gcd", "lcm"]:
            expr = handle_gcd_lcm(operatorname, args)
        elif operatorname == 'zeros':
            expr = sympy.zeros(*args)
        elif operatorname == 'ones':
            expr = sympy.ones(*args)
        elif operatorname == 'diag':
            expr = sympy.diag(*args)
        elif operatorname == 'hstack':
            expr = sympy.Matrix.hstack(*args)
        elif operatorname == 'vstack':
            expr = sympy.Matrix.vstack(*args)
        elif operatorname in ['orth', 'ortho', 'orthogonal', 'orthogonalize']:
            if len(args) == 1:
                # A single matrix: orthonormalise its columns.
                matrix = args[0]
                expr = sympy.matrices.GramSchmidt([matrix.col(i) for i in range(matrix.cols)], True)
            else:
                expr = sympy.matrices.GramSchmidt(args, True)
        else:
            raise Exception("Unsupported operator name: " + operatorname)
    elif name in ["gcd", "lcm"]:
        expr = handle_gcd_lcm(name, args)
    elif name in ["max", "min"]:
        expr = getattr(sympy.functions, name[0].upper() + name[1:])(*args, evaluate=False)

    if expr is None:
        raise Exception("Unsupported function: " + name)

    func_pow = _convert_func_pow(func)
    if func_pow is not None:
        expr = sympy.Pow(expr, func_pow, evaluate=False)
    return expr


def _convert_user_func(func):
    """Convert an application of an undefined function such as ``f(x, y)``."""
    f = sympy.Function(func.atom_expr_no_supexpr().getText())
    arg_texts = func.func_common_args().getText().split(",")
    if arg_texts[-1] == '':
        arg_texts = arg_texts[:-1]  # tolerate a trailing comma
    args = [latex2sympy(text, VARIABLE_VALUES) for text in arg_texts]

    exponent = _convert_func_pow(func)
    if exponent is not None:
        return sympy.Pow(f(*args), exponent, evaluate=False)
    return f(*args)


def _convert_sqrt(func):
    r"""Convert ``\sqrt{x}`` to ``x**(1/2)`` and ``\sqrt[n]{x}`` to ``x**(1/n)``."""
    expr = convert_expr(func.base)
    if func.root:
        r = convert_expr(func.root)
        return sympy.Pow(expr, 1 / r, evaluate=False)
    return sympy.Pow(expr, sympy.S.Half, evaluate=False)


def convert_func(func):
    """Convert a ``func`` parse node by dispatching on its kind.

    Returns ``None`` when the node is of no recognised kind.

    Raises:
        Exception: for an unsupported built-in function or operator name.
    """
    if func.func_normal_single_arg():
        return _convert_single_arg_func(func)
    if func.func_normal_multi_arg():
        return _convert_multi_arg_func(func)
    if func.atom_expr_no_supexpr():
        return _convert_user_func(func)
    if func.FUNC_INT():
        return handle_integral(func)
    if func.FUNC_SQRT():
        return _convert_sqrt(func)
    if func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    if func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    if func.FUNC_LIM():
        return handle_limit(func)
    if func.EXP_E():
        return handle_exp(func)
    return None
|
|
|
|
|
def convert_func_arg(arg):
    """Convert a function argument node.

    Parenthesised arguments carry an ``expr`` child. Unparenthesised ones
    carry an ``mp_nofunc`` child.
    """
    if not hasattr(arg, 'expr'):
        return convert_mp(arg.mp_nofunc())
    return convert_expr(arg.expr())
|
|
|
|
|
def handle_integral(func):
    r"""Convert an ``\int`` node into a ``sympy.Integral``.

    The integration variable comes from an explicit differential token. If
    there is none, the integrand is searched for a symbol named like ``dx``
    (or ``d\theta``). The last such symbol found is used and replaced by 1 in
    the integrand. If nothing is found, ``x`` is assumed. When a subscript is
    present, the integral is definite with bounds from the sub/superscripts.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    if func.DIFFERENTIAL():
        int_var = get_differential_var(func.DIFFERENTIAL())
    else:
        int_var = None
        for candidate in integrand.atoms(sympy.Symbol):
            text = str(candidate)
            if len(text) > 1 and text[0] == 'd':
                var_name = text[2:] if text[1] == '\\' else text[1:]
                int_var = sympy.Symbol(var_name, real=is_real)
                int_sym = candidate
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            int_var = sympy.Symbol('x', real=is_real)

    if not func.subexpr():
        return sympy.Integral(integrand, int_var)

    lower_node = func.subexpr()
    lower = convert_atom(lower_node.atom()) if lower_node.atom() else convert_expr(lower_node.expr())
    upper_node = func.supexpr()
    upper = convert_atom(upper_node.atom()) if upper_node.atom() else convert_expr(upper_node.expr())
    return sympy.Integral(integrand, (int_var, lower, upper))
|
|
|
|
|
def handle_sum_or_prod(func, name):
    r"""Convert ``\sum_{i=a}^{b}`` / ``\prod_{i=a}^{b}`` into ``sympy.Sum`` / ``sympy.Product``.

    Args:
        func: the function parse node.
        name: ``"summation"`` or ``"product"``; any other value returns ``None``.
    """
    body = convert_mp(func.mp())
    equality = func.subeq().equality()
    index = convert_expr(equality.expr(0))
    lower = convert_expr(equality.expr(1))
    sup = func.supexpr()
    upper = convert_expr(sup.expr()) if sup.expr() else convert_atom(sup.atom())

    limits = (index, lower, upper)
    if name == "summation":
        return sympy.Sum(body, limits)
    if name == "product":
        return sympy.Product(body, limits)
    return None
|
|
|
|
|
def handle_limit(func):
    r"""Convert ``\lim_{x \to a} f`` into ``sympy.Limit``.

    The variable defaults to ``x``. A minus sign in the limit subscript
    selects the ``"-"`` (left) direction, otherwise ``"+"``.
    """
    sub = func.limit_sub()
    if sub.LETTER_NO_E():
        var_name = sub.LETTER_NO_E().getText()
    elif sub.GREEK_CMD():
        var_name = sub.GREEK_CMD().getText()[1:].strip()
    elif sub.OTHER_SYMBOL_CMD():
        var_name = sub.OTHER_SYMBOL_CMD().getText().strip()
    else:
        var_name = 'x'
    limit_var = sympy.Symbol(var_name, real=is_real)

    direction = "-" if sub.SUB() else "+"
    approaching = convert_expr(sub.expr())
    content = convert_mp(func.mp())
    return sympy.Limit(content, limit_var, approaching, direction)
|
|
|
|
|
def handle_exp(func):
    """Convert the exponential-e token: ``e^{x}`` -> ``exp(x)``; bare ``e`` -> ``exp(1)``."""
    sup = func.supexpr()
    if not sup:
        return sympy.exp(1)
    if sup.expr():
        exponent = convert_expr(sup.expr())
    else:
        exponent = convert_atom(sup.atom())
    return sympy.exp(exponent)
|
|
|
|
|
def handle_gcd_lcm(f, args):
    """
    Compute gcd() or lcm() of *args* and wrap the result in UnevaluatedExpr.

    f: str - name of the SymPy function to call ("gcd" or "lcm")
    args: List[Expr] - the function arguments; each is passed through
        sympy.nsimplify first, and the resulting tuple is passed as one argument
    """
    exact_args = tuple(sympy.nsimplify(value) for value in args)
    combine = getattr(sympy, f)
    return sympy.UnevaluatedExpr(combine(exact_args))
|
|
|
|
|
def handle_floor(expr):
    """
    Wrap *expr* in an unevaluated floor().

    expr: Expr - sympy expression to take the floor of
    """
    floor_fn = sympy.functions.floor
    return floor_fn(expr, evaluate=False)
|
|
|
|
|
def handle_ceil(expr):
    """
    Wrap *expr* in an unevaluated ceiling().

    expr: Expr - sympy expression to take the ceiling of
    """
    ceiling_fn = sympy.functions.ceiling
    return ceiling_fn(expr, evaluate=False)
|
|
|
|
|
def get_differential_var(d):
    """Return the integration/differentiation variable of differential token *d* as a Symbol."""
    var_name = get_differential_var_str(d.getText())
    return sympy.Symbol(var_name, real=is_real)
|
|
|
|
|
def get_differential_var_str(text):
    r"""Extract the variable name from differential text such as ``dx`` or ``d \theta``.

    The first character (the ``d``) is dropped. Whitespace (space, CR, LF,
    tab) after it is skipped, and one leading backslash is removed from a
    command name.

    Raises:
        ValueError: if nothing but whitespace follows the first character.
            The previous loop left its index unbound and raised
            UnboundLocalError in that case.
    """
    name = text[1:].lstrip(" \r\n\t")
    if not name:
        raise ValueError("Differential %r has no variable" % (text,))
    if name[0] == "\\":
        name = name[1:]
    return name
|
|
|
|
|
def latex(tex):
    r"""Render *tex* with ``sympy.latex`` and post-process the text.

    * Every ``\frac``/``\dfrac``/``\tfrac`` becomes the module's ``frac_type``.
    * ``\left[\begin{matrix}`` ... becomes ``\begin{bmatrix}`` ...
    * All ``\left`` / ``\right`` are removed.
    * ``" )"`` becomes ``")"``.
    * ``\log`` becomes ``\ln``.
    """
    rendered = sympy.latex(tex)
    for frac_cmd in (r'\frac', r'\dfrac', r'\tfrac'):
        rendered = rendered.replace(frac_cmd, frac_type)
    rewrites = (
        (r'\left[\begin{matrix}', r'\begin{bmatrix}'),
        (r'\end{matrix}\right]', r'\end{bmatrix}'),
        (r'\left', r''),
        (r'\right', r''),
        (r' )', r')'),
        (r'\log', r'\ln'),
    )
    for old, new in rewrites:
        rendered = rendered.replace(old, new)
    return rendered
|
|
|
|
|
def latex2latex(tex):
    """Parse *tex*, evaluate it, and render the result back to LaTeX.

    Lists, tuples and dicts (e.g. relation lists) are rendered as-is. Other
    results have the bindings in ``variances`` substituted, are ``doit()``-ed
    twice, then simplified before rendering.
    """
    parsed = latex2sympy(tex)
    if isinstance(parsed, (list, tuple, dict)):
        return latex(parsed)
    evaluated = parsed.subs(variances).doit().doit()
    return latex(simplify(evaluated))
|
|
|
|
|
|
|
# Import-time setup (runs on every import of this module).
# Parse 'i=I' and 'j=I'. NOTE(review): presumably '=' here is lexed so that
# convert_relation records i and j as the imaginary unit in variances/var;
# confirm against the grammar (EQUAL only builds an Eq without side effects).
latex2latex('i=I')
latex2latex('j=I')

# Pre-register \bm{I}_1 ... \bm{I}_9 as mutable identity matrices of matching
# size, keyed by a plain Symbol, a MatrixSymbol, and the Symbol's string name.
# The loop names (i, lh, lh_m, rh) remain module globals afterwards.
for i in range(1, 10):
    lh = sympy.Symbol(r'\bm{I}_' + str(i), real=False)
    lh_m = sympy.MatrixSymbol(r'\bm{I}_' + str(i), i, i)
    rh = sympy.Identity(i).as_mutable()
    variances[lh] = rh
    variances[lh_m] = rh
    var[str(lh)] = rh
|
|
|
if __name__ == '__main__':
    # Demo: parse one LaTeX expression, then print the input, the raw
    # conversion, and the evaluated result rendered back to LaTeX.
    sample_tex = r"\operatorname{rows}(\begin{bmatrix}1 & 2 \\ 3 & 4\end{bmatrix})"
    raw_result = latex2sympy(sample_tex)
    print("latex:", sample_tex)
    print("raw_math:", raw_result)
    print("cal:", latex2latex(sample_tex))
|
|