from transformers import AutoModelForCausalLM


def convert_flax_to_pytorch(model_dir=".", output_dir=None):
    """Convert a Flax causal-LM checkpoint into a PyTorch checkpoint.

    Loads the checkpoint in ``model_dir`` as a PyTorch
    ``AutoModelForCausalLM``. ``from_flax=True`` asks transformers to read the
    Flax weights. The resulting PyTorch model is then written out with
    ``save_pretrained``.

    Args:
        model_dir: Directory containing the Flax checkpoint and its config.
            Defaults to the current directory, as in the original script.
        output_dir: Directory to write the PyTorch checkpoint into. Defaults to
            ``model_dir``, so the PyTorch weights are saved alongside the Flax
            ones (the original script's behavior).

    Returns:
        The loaded PyTorch model.
    """
    if output_dir is None:
        output_dir = model_dir
    # NOTE: loading with from_flax=True requires Flax/JAX to be installed in
    # addition to PyTorch.
    pt_model = AutoModelForCausalLM.from_pretrained(model_dir, from_flax=True)
    pt_model.save_pretrained(output_dir)
    return pt_model


if __name__ == "__main__":
    # Keep the module-level ``pt_model`` binding of the original script when
    # run directly; importing this module no longer triggers a conversion.
    pt_model = convert_flax_to_pytorch()