from decimal import *
import sys

sys.setrecursionlimit(10 ** 6)


def int1(x):
    """Parse *x* as an int and subtract one (1-based -> 0-based index)."""
    return int(x) - 1


def p2D(x):
    """Print every element of *x* on its own line."""
    print(*x, sep="\n")


def MI():
    """Read one stdin line and return a lazy ``map`` of its whitespace-separated ints."""
    return map(int, sys.stdin.readline().split())


def LI():
    """Read one stdin line and return its whitespace-separated ints as a list."""
    return list(map(int, sys.stdin.readline().split()))


def LLI(rows_number):
    """Read *rows_number* stdin lines, each parsed with ``LI()``."""
    return [LI() for _ in range(rows_number)]


def swap_max_to_front(s):
    """Swap the first character of *s* with the last occurrence of its largest character.

    Args:
        s: Input string (may be empty).

    Returns:
        A new string. It equals *s* when *s* is empty or its largest character
        occurs only at index 0. Otherwise, ``s[0]`` and the character at the
        LAST index holding ``max(s)`` are exchanged.

    NOTE(review): only index 0 is ever improved. If ``s[0]`` is already the
    maximum, nothing further happens (e.g. "bab" stays "bab"). If the task is
    "lexicographically largest string after at most one swap", this is
    incomplete. Confirm against the problem statement.
    """
    # Guard: max() of an empty sequence raises ValueError.
    if not s:
        return s
    idx = s.rfind(max(s))
    # idx == 0 means the maximum only sits at the front, so no swap is needed.
    if idx <= 0:
        return s
    return s[idx] + s[1:idx] + s[0] + s[idx + 1:]


def main():
    """Read one line from stdin, apply ``swap_max_to_front``, and print the result."""
    print(swap_max_to_front(input()))


# Guarded so that importing this module does not block on / consume stdin.
if __name__ == "__main__":
    main()