Make sure hidden state and wte weights are on same device when in parallel model. 28721e3 muelletm commited on May 27, 2023